diff --git a/GNUmakefile b/GNUmakefile index 604394716..74e07c265 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -103,8 +103,8 @@ abs_prefix := ${abspath ${prefix}} export CXX blas blas_int blas_threaded openmp static gpu_backend CXXFLAGS += -O3 -std=c++17 -Wall -Wshadow -pedantic -MMD -NVCCFLAGS += -O3 -std=c++11 --compiler-options '-Wall -Wno-unused-function' -HIPCCFLAGS += -std=c++14 -DTCE_HIP -fno-gpu-rdc +NVCCFLAGS += -O3 -std=c++17 --compiler-options '-Wall -Wno-unused-function' +HIPCCFLAGS += -std=c++17 -DTCE_HIP -fno-gpu-rdc force: ; @@ -480,6 +480,7 @@ ifneq ($(only_unit),1) src/internal/internal_getrf.cc \ src/internal/internal_getrf_nopiv.cc \ src/internal/internal_getrf_tntpiv.cc \ + src/internal/internal_getrf_addmod.cc \ src/internal/internal_hbnorm.cc \ src/internal/internal_hebr.cc \ src/internal/internal_hegst.cc \ @@ -501,6 +502,8 @@ ifneq ($(only_unit),1) src/internal/internal_trnorm.cc \ src/internal/internal_trsm.cc \ src/internal/internal_trsmA.cc \ + src/internal/internal_trsm_addmod.cc \ + src/internal/internal_trsmA_addmod.cc \ src/internal/internal_trtri.cc \ src/internal/internal_trtrm.cc \ src/internal/internal_ttlqt.cc \ @@ -530,6 +533,7 @@ cuda_src := \ src/cuda/device_synorm.cu \ src/cuda/device_transpose.cu \ src/cuda/device_trnorm.cu \ + src/cuda/device_trsm_addmod.cu \ src/cuda/device_tzadd.cu \ src/cuda/device_tzcopy.cu \ src/cuda/device_tzscale.cu \ @@ -599,13 +603,17 @@ ifneq ($(only_unit),1) src/gesv_mixed_gmres.cc \ src/gesv_nopiv.cc \ src/gesv_rbt.cc \ + src/gesv_addmod.cc \ + src/gesv_addmod_ir.cc \ src/getrf.cc \ src/getrf_nopiv.cc \ src/getrf_tntpiv.cc \ + src/getrf_addmod.cc \ src/getri.cc \ src/getriOOP.cc \ src/getrs.cc \ src/getrs_nopiv.cc \ + src/getrs_addmod.cc \ src/hb2st.cc \ src/hbmm.cc \ src/he2hb.cc \ @@ -657,6 +665,9 @@ ifneq ($(only_unit),1) src/trsm.cc \ src/trsmA.cc \ src/trsmB.cc \ + src/trsm_addmod.cc \ + src/trsmA_addmod.cc \ + src/trsmB_addmod.cc \ src/trtri.cc \ src/trtrm.cc \ src/unmlq.cc \ @@ -667,6 +678,8 @@ ifneq ($(only_unit),1) src/work/work_trmm.cc \ src/work/work_trsm.cc \ src/work/work_trsmA.cc \ + src/work/work_trsm_addmod.cc \ + src/work/work_trsmA_addmod.cc \ # End. Add alphabetically. endif diff --git a/include/slate/addmod.hh b/include/slate/addmod.hh new file mode 100644 index 000000000..fe91a6f42 --- /dev/null +++ b/include/slate/addmod.hh @@ -0,0 +1,88 @@ +// Copyright (c) 2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. 
+
+//------------------------------------------------------------------------------
+/// @file
+///
+#ifndef SLATE_ADDMOD_HH
+#define SLATE_ADDMOD_HH
+
+#include "slate/Matrix.hh"
+
+#include <vector>
+
+namespace slate {
+
+//------------------------------------------------------------------------------
+// auxiliary type for modifications and Woodbury matrices
+
+template <typename scalar_t>
+class AddModFactors {
+    using real_t = blas::real_type<scalar_t>;
+public:
+    int64_t block_size;
+    int64_t num_modifications;
+    BlockFactor factorType;
+
+    Matrix<scalar_t> A;
+    Matrix<scalar_t> U_factors;
+    Matrix<scalar_t> VT_factors;
+    std::vector<std::vector<real_t>> singular_values;
+    std::vector<std::vector<scalar_t>> modifications;
+    std::vector<std::vector<int64_t>> modification_indices;
+    Matrix<scalar_t> capacitance_matrix;
+    Pivots capacitance_pivots;
+
+    Matrix<scalar_t> S_VT_Rinv;
+    Matrix<scalar_t> Linv_U;
+};
+
+//------------------------------------------------------------------------------
+// Routines
+
+template <typename scalar_t>
+void gesv_addmod(Matrix<scalar_t>& A, AddModFactors<scalar_t>& W, Matrix<scalar_t>& B,
+                 Options const& opts = Options());
+
+template <typename scalar_t>
+void gesv_addmod_ir( Matrix<scalar_t>& A, AddModFactors<scalar_t>& W,
+                     Matrix<scalar_t>& B,
+                     Matrix<scalar_t>& X,
+                     int& iter,
+                     Options const& opts);
+
+template <typename scalar_t>
+void getrf_addmod(Matrix<scalar_t>& A, AddModFactors<scalar_t>& W,
+                  Options const& opts = Options());
+
+template <typename scalar_t>
+void getrs_addmod(AddModFactors<scalar_t>& W,
+                  Matrix<scalar_t>& B,
+                  Options const& opts);
+
+template <typename scalar_t>
+void trsm_addmod(
+    Side side, Uplo uplo,
+    scalar_t alpha, AddModFactors<scalar_t>& W,
+                    Matrix<scalar_t>& B,
+    Options const& opts = Options());
+
+template <typename scalar_t>
+void trsmA_addmod(
+    Side side, Uplo uplo,
+    scalar_t alpha, AddModFactors<scalar_t>& W,
+                    Matrix<scalar_t>& B,
+    Options const& opts = Options());
+
+template <typename scalar_t>
+void trsmB_addmod(
+    Side side, Uplo uplo,
+    scalar_t alpha, AddModFactors<scalar_t>& W,
+                    Matrix<scalar_t>& B,
+    Options const& opts = Options());
+
+} // namespace slate
+
+#endif // SLATE_ADDMOD_HH
diff --git a/include/slate/enums.hh b/include/slate/enums.hh
index 09f37ced2..73f94faa4 100644
--- a/include/slate/enums.hh
+++ b/include/slate/enums.hh
@@ -9,6 +9,7 @@
 #ifndef SLATE_ENUMS_HH
 #define SLATE_ENUMS_HH
 
+#include <algorithm>
 #include
 #include
 
@@ -74,6 +75,9 @@ enum class Option : char {
     MaxIterations,      ///< maximum iteration count
     UseFallbackSolver,  ///< whether to fallback to a robust solver if iterations do not converge
     PivotThreshold,     ///< threshold for pivoting, >= 0, <= 1
+    AdditiveTolerance,  ///< tolerance for additive modification, >= 0
+    UseWoodbury,        ///< whether to apply the Woodbury formula
+    BlockFactor,        ///< how to factor the diagonal blocks in the addmod solver
 
     // Printing parameters
     PrintVerbose = 50,  ///< verbose, 0: no printing,
@@ -87,7 +91,6 @@ enum class Option : char {
     PrintWidth,         ///< width print format specifier
     PrintPrecision,     ///< precision print format specifier
                         ///< For correct printing, PrintWidth = PrintPrecision + 6.
-
     // Methods, listed alphabetically.
MethodCholQR = 60, ///< Select the algorithm to compute A^H * A MethodEig, ///< Select the algorithm to compute eigenpairs of tridiagonal matrix @@ -144,6 +147,44 @@ enum MOSI { typedef short MOSI_State; + +//------------------------------------------------------------------------------ +enum class BlockFactor : char { + SVD, + QLP, + QRCP, + QR +}; +inline BlockFactor str2blockfactor(const char* method) +{ + std::string method_ = method; + std::transform( + method_.begin(), method_.end(), method_.begin(), ::tolower ); + + if (method_ == "svd") + return BlockFactor::SVD; + else if (method_ == "qlp") + return BlockFactor::QLP; + else if (method_ == "qrcp") + return BlockFactor::QRCP; + else if (method_ == "qr") + return BlockFactor::QR; + else + // throw slate::Exception("unknown BlockFactor"); + return BlockFactor::SVD; +} + +inline const char* blockfactor2str(BlockFactor method) +{ + switch (method) { + case BlockFactor::SVD: return "SVD"; + case BlockFactor::QLP: return "QLP"; + case BlockFactor::QRCP: return "QRCP"; + case BlockFactor::QR: return "QR"; + default: return "error"; + } +} + } // namespace slate #endif // SLATE_ENUMS_HH diff --git a/include/slate/internal/device.hh b/include/slate/internal/device.hh index a7092b7f0..0d6e42916 100644 --- a/include/slate/internal/device.hh +++ b/include/slate/internal/device.hh @@ -248,6 +248,26 @@ void trnorm( blas::real_type* values, int64_t ldv, int64_t batch_count, blas::Queue& queue); +//------------------------------------------------------------------------------ +template +void batch_trsm_addmod( + BlockFactor factorType, + blas::Layout layout, + blas::Side side, + blas::Uplo uplo, + int64_t mb, + int64_t nb, + int64_t ib, + scalar_t alpha, + std::vector Aarray, int64_t ldda, + std::vector Uarray, int64_t lddu, + std::vector VTarray, int64_t lddvt, + std::vector*> Sarray, + std::vector Barray, int64_t lddb, + std::vector dwork, + const size_t batch, + blas::Queue &queue ); + //------------------------------------------------------------------------------ // In-place, square. template diff --git a/include/slate/method.hh b/include/slate/method.hh index e4076c2f9..3a19fc401 100644 --- a/include/slate/method.hh +++ b/include/slate/method.hh @@ -32,11 +32,16 @@ namespace MethodTrsm { const Method TrsmB = 2; ///< Select trsmB algorithm template - inline Method select_algo(TA& A, TB& B, Options const& opts) { + inline Method select_algo(TA& A, TB& B, Side side, Options const& opts) { Target target = get_option( opts, Option::Target, Target::HostTask ); int n_devices = A.num_devices(); - Method method = (B.nt() < 2 ? TrsmA : TrsmB); + Method method; + if (side == Side::Left) { + method = (A.nt()>B.nt() && B.nt() < 2 ? TrsmA : TrsmB); + } else { + method = (A.mt()>B.mt() && B.mt() < 2 ? TrsmA : TrsmB); + } if (method == TrsmA && target == Target::Devices && n_devices > 1) method = TrsmB; diff --git a/include/slate/slate.hh b/include/slate/slate.hh index 2229958cd..6de6aa97c 100644 --- a/include/slate/slate.hh +++ b/include/slate/slate.hh @@ -21,6 +21,8 @@ #include "slate/types.hh" #include "slate/print.hh" +#include "slate/addmod.hh" + //------------------------------------------------------------------------------ /// @namespace slate /// SLATE's top-level namespace. 
@@ -644,6 +646,9 @@ void getri( Matrix& B, Options const& opts = Options()); +//----------------------------------------- +// LU with additive modifications + //----------------------------------------- // Cholesky diff --git a/include/slate/types.hh b/include/slate/types.hh index 5d65b5a39..c39f60790 100644 --- a/include/slate/types.hh +++ b/include/slate/types.hh @@ -46,6 +46,9 @@ public: OptionValue(Target t) : i_(int(t)) {} + OptionValue(BlockFactor f) : i_(int(f)) + {} + OptionValue(MethodEig m) : i_(int(m)) {} diff --git a/src/cuda/device_trsm_addmod.cu b/src/cuda/device_trsm_addmod.cu new file mode 100644 index 000000000..8741f7d3e --- /dev/null +++ b/src/cuda/device_trsm_addmod.cu @@ -0,0 +1,1122 @@ + +#include +#include +#include +#include + +#include "slate/Exception.hh" +#include "slate/internal/device.hh" + +#include "device_util.cuh" + +#include + + +namespace slate { +namespace device { + +// templated casting from C++ types to cuda types +template +static scalar_t to_cutype(scalar_t z) { + return z; +} +static cuFloatComplex to_cutype(std::complex z) { + return make_cuFloatComplex(z.real(), z.imag()); +} +static cuFloatComplex** to_cutype(std::complex** z) { + return (cuFloatComplex**)z; +} +static cuDoubleComplex to_cutype(std::complex z) { + return make_cuDoubleComplex(z.real(), z.imag()); +} +static cuDoubleComplex** to_cutype(std::complex** z) { + return (cuDoubleComplex**)z; +} + +template +__device__ scalar_t as_scalar(double r) { + return scalar_t(r); +} +template<> +__device__ cuDoubleComplex as_scalar(double r) { + return make_cuDoubleComplex(r, 0.0); +} +template<> +__device__ cuFloatComplex as_scalar(double r) { + return make_cuFloatComplex(r, 0.0); +} + + +template +void batch_gemm( + blas::Layout layout, + blas::Op transA, + blas::Op transB, + int64_t mb, + int64_t nb, + int64_t kb, + scalar_t alpha, + std::vector& Aarray, int64_t Ai, int64_t Aj, int64_t ldda, + std::vector& Barray, int64_t Bi, int64_t Bj, int64_t lddb, + scalar_t beta, + std::vector& Carray, int64_t Ci, int64_t Cj, int64_t lddc, + const size_t batch, + blas::Queue &queue ) +{ + std::vector transA_v {transA}; + std::vector transB_v {transB}; + std::vector m_v {mb}; + std::vector n_v {nb}; + std::vector k_v {kb}; + std::vector alpha_v {alpha}; + std::vector beta_v {beta }; + std::vector info (0); + + std::vector Aarray_v(batch); + std::vector ldda_v {ldda}; + std::vector Barray_v(batch); + std::vector lddb_v {lddb}; + std::vector Carray_v(batch); + std::vector lddc_v {lddc}; + + if (layout == blas::Layout::ColMajor) { + for (size_t i = 0; i < batch; ++i) { + Aarray_v[i] = Aarray[i] + Ai + Aj*ldda; + Barray_v[i] = Barray[i] + Bi + Bj*lddb; + Carray_v[i] = Carray[i] + Ci + Cj*lddc; + } + } + else { + for (size_t i = 0; i < batch; ++i) { + Aarray_v[i] = Aarray[i] + Ai*ldda + Aj; + Barray_v[i] = Barray[i] + Bi*lddb + Bj; + Carray_v[i] = Carray[i] + Ci*lddc + Cj; + } + } + + blas::batch::gemm( + layout, transA_v, transB_v, m_v, n_v, k_v, + alpha_v, Aarray_v, ldda_v, + Barray_v, lddb_v, + beta_v, Carray_v, lddc_v, + batch, info, queue); +} +template +void batch_trsm( + blas::Layout layout, + blas::Side side, + blas::Uplo uplo, + blas::Op trans, + blas::Diag diag, + int64_t mb, + int64_t nb, + scalar_t alpha, + std::vector& Aarray, int64_t Ai, int64_t Aj, int64_t ldda, + std::vector& Barray, int64_t Bi, int64_t Bj, int64_t lddb, + const size_t batch, + blas::Queue &queue ) +{ + std::vector side_v {side}; + std::vector uplo_v {uplo}; + std::vector trans_v {trans}; + std::vector diag_v {diag}; + 
std::vector m_v {mb}; + std::vector n_v {nb}; + std::vector alpha_v {alpha}; + std::vector info (0); + + std::vector Aarray_v(batch); + std::vector ldda_v {ldda}; + std::vector Barray_v(batch); + std::vector lddb_v {lddb}; + + if (layout == blas::Layout::ColMajor) { + for (size_t i = 0; i < batch; ++i) { + Aarray_v[i] = Aarray[i] + Ai + Aj*ldda; + Barray_v[i] = Barray[i] + Bi + Bj*lddb; + } + } + else { + for (size_t i = 0; i < batch; ++i) { + Aarray_v[i] = Aarray[i] + Ai*ldda + Aj; + Barray_v[i] = Barray[i] + Bi*lddb + Bj; + } + } + + blas::batch::trsm( + layout, side_v, uplo_v, trans_v, diag_v, m_v, n_v, + alpha_v, Aarray_v, ldda_v, + Barray_v, lddb_v, + batch, info, queue); +} + +template +__device__ void tb_upper_right(int mb, int nb, + scalar_t alpha, scalar_t* __restrict__ dB, int64_t lddb, + scalar_t* __restrict__ dVT, int64_t lddvt, + blas::real_type* __restrict__ dS) +{ + // Due to linking, dynamic shared memory can't be declared as scalar_t + extern __shared__ char shared_ptr[]; + + scalar_t* sW = (scalar_t*) shared_ptr; + int ldsw = 32; + + const scalar_t zero = as_scalar(0.0); + + const int offseti = threadIdx.x; + const int stridei = 32; //blockDim.x; + const int offsetj = threadIdx.y; + const int stridej = blockDim.y; + + for (int ii = 0; ii < mb; ii += stridei) { + int iib = min(32, mb-ii); + + int i = offseti; + if (i < iib) { + for (int j = offsetj; j < nb; j += stridej) { + + scalar_t sum = zero; + scalar_t* dBik = dB + ii+i; + scalar_t* dVTjk = dVT + j; + for (int k = 0; k < nb; k += 1) { + sum += dBik[0] * conj(dVTjk[0]); + dBik += lddb; + dVTjk += lddvt; + } + + blas::real_type S_j = dS[j]; + sW[i + j*ldsw] = alpha*sum / S_j; + } + } + __syncthreads(); + if (i < iib) { + for (int j = offsetj; j < nb; j += stridej) { + dB[ii+i + j*lddb] = sW[i + j*ldsw]; + } + } + } +} + +// Compute a gemm within the threadblock +// beta=0 +template +__device__ void tb_gemm(bool transA, bool transB, + int mb, int nb, int kb, + scalar_t alpha, scalar_t* __restrict__ dA, int64_t ldda, + scalar_t* __restrict__ dB, int64_t lddb, + scalar_t* __restrict__ dC, int64_t lddc) +{ + + // Due to linking, dynamic shared memory can't be declared as scalar_t + extern __shared__ char shared_ptr[]; + + const scalar_t zero = as_scalar(0.0); + + const int offseti = threadIdx.x; + const int stridei = 32; //blockDim.x; + const int offsetj = threadIdx.y; + const int stridej = blockDim.y; + + if (transA) { + if (transB) { + for (int i = offseti; i < mb; i += stridei) { + for (int j = offsetj; j < nb; j += stridej) { + scalar_t sum = zero; + for (int k = 0; k < kb; k += 1) { + sum += conj(dA[k + i*ldda]) * conj(dB[j + k*lddb]); + } + dC[i + j*lddc] = alpha*sum; + } + } + } else { + scalar_t* sA = (scalar_t*)shared_ptr; + const int ldsa = 33; + scalar_t* sB = sA + ldsa*32; + const int ldsb = 33; + scalar_t* sC = sB + ldsb*32; + const int ldsc = 33; + + for (int ii = 0; ii < mb; ii += 32) { + int iib = min(32, mb-ii); + + for (int j = offsetj; j < nb; j += stridej) { + sC[offseti + j*ldsc] = zero; + } + + for (int kk = 0; kk < kb; kk += 32) { + int kkb = min(32, kb-kk); + + __syncthreads(); + + // Pad sA and sB w/ zeros up to 32 so we can read off the end + for (int i = offsetj; i < iib; i += stridej) { + if (offseti < kkb && i < iib) { + sA[offseti + i*ldsa] = dA[kk+offseti + (ii+i)*ldda]; + } else { + sA[offseti + i*ldsa] = zero; + } + } + for (int j = offsetj; j < nb; j += stridej) { + if (offseti < kkb && j < nb) { + sB[offseti + j*ldsb] = dB[kk+offseti + j*lddb]; + } else { + sB[offseti + j*ldsb] = zero; + 
} + } + __syncthreads(); + int i = offseti; + if (i < iib) { + for (int j = offsetj; j < nb; j += stridej) { + scalar_t sum = zero; + #pragma unroll 8 + for (int k = 0; k < 32; k += 1) { + sum += conj(sA[k + i*ldsa]) * sB[k + j*ldsb]; + } + sC[i + j*ldsc] += alpha*sum; + } + } + } + if (offseti < iib) { + for (int j = offsetj; j < nb; j += stridej) { + dC[ii+offseti + j*lddc] = sC[offseti + j*ldsc]; + } + } + } + } + } else { + if (transB) { + for (int i = offseti; i < mb; i += stridei) { + for (int j = offsetj; j < nb; j += stridej) { + scalar_t sum = zero; + scalar_t* dAik = dA + i; + scalar_t* dBjk = dB + j; + for (int k = 0; k < kb; k += 1) { + sum += dAik[0] * conj(dBjk[0]); + dAik += ldda; + dBjk += lddb; + } + dC[i + j*lddc] = alpha*sum; + } + } + } else { + for (int i = offseti; i < mb; i += stridei) { + for (int j = offsetj; j < nb; j += stridej) { + scalar_t sum = zero; + for (int k = 0; k < kb; k += 1) { + sum += dA[i + k*ldda] * dB[k + j*lddb]; + } + dC[i + j*lddc] = alpha*sum; + } + } + } + } +} + +template +__device__ void tb_copy( + int mb, + int nb, + scalar_t* __restrict__ dA, int64_t ldda, + scalar_t* __restrict__ dB, int64_t lddb) +{ + int offseti = threadIdx.x; + int stridei = blockDim.x; + int offsetj = threadIdx.y; + int stridej = blockDim.y; + + for (int i = offseti; i < mb; i += stridei) { + for (int j = offsetj; j < nb; j += stridej) { + dB[i + j*lddb] = dA[i + j*ldda]; + } + } +} + +template +__device__ void tb_scale_copy( + bool isLeft, + int mb, + int nb, + blas::real_type* dS, + scalar_t* __restrict__ dA, int64_t ldda, + scalar_t* __restrict__ dB, int64_t lddb) +{ + int offseti = threadIdx.x; + int stridei = blockDim.x; + int offsetj = threadIdx.y; + int stridej = blockDim.y; + + if (isLeft) { + for (int i = offseti; i < mb; i += stridei) { + blas::real_type S_i = dS[i]; + blas::real_type inv_S_i = 1/S_i; + for (int j = offsetj; j < nb; j += stridej) { + dB[i + j*lddb] = dA[i + j*ldda] * inv_S_i; + } + } + } else { + for (int j = offsetj; j < nb; j += stridej) { + blas::real_type S_j = dS[j]; + blas::real_type inv_S_j = 1/S_j; + for (int i = offseti; i < mb; i += stridei) { + dB[i + j*lddb] = dA[i + j*ldda] * inv_S_j; + } + } + } +} + +template +__global__ void __launch_bounds__(512,2) batch_diag_kernel( + int mb, int nb, + scalar_t alpha, + scalar_t** dUarray, int64_t Ui, int64_t lddu, + scalar_t** dVTarray, int64_t VTi, int64_t lddvt, + blas::real_type** dSarray, + scalar_t** dBarray, int64_t Bi, int64_t Bj, int64_t lddb, + scalar_t** dWarray, int64_t lddw) +{ + using real_t = blas::real_type; + + int batch = blockIdx.x; + + scalar_t* B_local = dBarray[batch] + Bi + Bj*lddb; + scalar_t* W_local = dWarray[batch]; + + if (!isUpper && !isLeft) { // lower right + + int step = (mb-1)/gridDim.y + 1; + B_local += step * blockIdx.y; + W_local += step * blockIdx.y; + int mb_local = min(step, mb - step*blockIdx.y); + + scalar_t* U_local = dUarray[batch] + Ui + Ui*lddu; + tb_gemm(false, true, mb_local, nb, nb, + alpha, B_local, lddb, + U_local, lddu, + W_local, lddw); + __syncthreads(); + tb_copy(mb_local, nb, W_local, lddw, B_local, lddb); + + } else if (!isUpper && isLeft) { // lower left + + int step = (nb-1)/gridDim.y + 1; + B_local += step * blockIdx.y * lddb; + W_local += step * blockIdx.y * lddw; + int nb_local = min(step, nb - step*blockIdx.y); + + scalar_t* U_local = dUarray[batch] + Ui + Ui*lddu; + tb_copy(mb, nb_local, B_local, lddb, W_local, lddw); + // tb_gemm(t, f) synchronizes before accessing A or B + tb_gemm(true, false, mb, nb_local, mb, + alpha, 
U_local, lddu, + W_local, lddw, + B_local, lddb); + + } else if (isUpper && !isLeft) { // upper right + + int step = (mb-1)/gridDim.y + 1; + B_local += step * blockIdx.y; + int mb_local = min(step, mb - step*blockIdx.y); + + scalar_t* VT_local = dVTarray[batch] + VTi + VTi*lddvt; + real_t* S_local = dSarray[batch] + VTi; + tb_upper_right(mb_local, nb, + alpha, B_local, lddb, + VT_local, lddvt, + S_local); + + } else if (isUpper && isLeft) { // upper left + + int step = (nb-1)/gridDim.y + 1; + B_local += step * blockIdx.y * lddb; + W_local += step * blockIdx.y * lddw; + int nb_local = min(step, nb - step*blockIdx.y); + + scalar_t* VT_local = dVTarray[batch] + VTi + VTi*lddvt; + real_t* S_local = dSarray[batch] + VTi; + tb_scale_copy(true, mb, nb_local, + S_local, + B_local, lddb, + W_local, lddw); + // tb_gemm(t, f) synchronizes before accessing A or B + tb_gemm(true, false, mb, nb_local, mb, + alpha, VT_local, lddvt, + W_local, lddw, + B_local, lddb); + } +} + +template +void batch_trsm_addmod_diag( + blas::Layout layout, + //blas::Side side, + //blas::Uplo uplo, + int64_t mb, + int64_t nb, + scalar_t alpha, + std::vector& Uarray, int64_t Ui, int64_t Uj, int64_t lddu, + std::vector& VTarray, int64_t VTi, int64_t VTj, int64_t lddvt, + std::vector*>& Sarray, int64_t Si, + std::vector& Barray, int64_t Bi, int64_t Bj, int64_t lddb, + std::vector& Warray, int64_t lddw, + const size_t batch, + blas::Queue &queue ) +{ + using real_t = blas::real_type; + + assert(layout == blas::Layout::ColMajor); + assert(blas::MaxBatchChunk >= batch); + constexpr int isUpper = uplo == blas::Uplo::Upper; + constexpr int isLeft = side == blas::Side::Left; + + assert(Ui == Uj && VTi == VTj && VTi == Si); + + queue.work_ensure_size( 2*batch ); + + scalar_t** dBarray = (scalar_t**)queue.work(); + scalar_t** dWarray = dBarray + batch; + // need either U or (VT and S), can use the same dev ptr array + scalar_t** dUarray = dWarray + batch; + scalar_t** dVTarray = dUarray; + real_t** dSarray = (real_t**)(dUarray + batch); + + blas::device_copy_vector( batch, Barray.data(), 1, dBarray, 1, queue ); + if (!isUpper || isLeft) { + // upper right doesn't need workspace array + blas::device_copy_vector( batch, Warray.data(), 1, dWarray, 1, queue ); + } + if (isUpper) { + blas::device_copy_vector( batch, VTarray.data(), 1, dVTarray, 1, queue ); + blas::device_copy_vector( batch, Sarray.data(), 1, dSarray, 1, queue ); + } else { + blas::device_copy_vector( batch, Uarray.data(), 1, dUarray, 1, queue ); + } + + int nrhs = isLeft ? nb : mb; + + dim3 grid_dim; + grid_dim.x = batch; + grid_dim.y = (nrhs-1)/32 + 1; + + dim3 block_dim; + block_dim.x = 32; + block_dim.y = 16; + + int shmem_size = (isLeft ? 3*32*33 : (isUpper ? 
32*nb : 0))*sizeof(scalar_t); + + batch_diag_kernel<<< grid_dim, block_dim, shmem_size, queue.stream() >>>( + mb, nb, + to_cutype(alpha), + to_cutype(dUarray), Ui, lddu, + to_cutype(dVTarray), VTi, lddvt, + dSarray, + to_cutype(dBarray), Bi, Bj, lddb, + to_cutype(dWarray), lddw); +} + +template +__global__ void __launch_bounds__(512,4) batch_trsm_scale_copy_kernel( + int64_t mb, + int64_t nb, + blas::real_type** dSarray, int64_t Si, + scalar_t** dAarray, int64_t Ai, int64_t Aj, int64_t ldda, + scalar_t** dBarray, int64_t Bi, int64_t Bj, int64_t lddb) +{ + using real_t = blas::real_type; + + int batch = blockIdx.x; + int blkidx_m = blockIdx.y; + int blksiz_m = gridDim.y; + int blkidx_n = blockIdx.z; + int blksiz_n = gridDim.z; + + int64_t step_m = (mb-1)/blksiz_m + 1; + int64_t m_offset = step_m*blkidx_m; + int64_t mb_local = min(step_m, mb - m_offset); + + int64_t step_n = (nb-1)/blksiz_n + 1; + int64_t n_offset = step_n*blkidx_n; + int64_t nb_local = min(step_n, nb - n_offset); + + + real_t* S_local = dSarray[batch] + Si + (isLeft ? m_offset : n_offset); + scalar_t* A_local = dAarray[batch] + (Ai+m_offset) + (Aj+n_offset)*ldda; + scalar_t* B_local = dBarray[batch] + (Bi+m_offset) + (Bj+n_offset)*lddb; + + tb_scale_copy(isLeft, mb_local, nb_local, S_local, A_local, ldda, B_local, lddb); +} + +template +void batch_scale_copy( + blas::Layout layout, + blas::Side side, + int64_t mb, + int64_t nb, + std::vector*>& Sarray, int64_t Si, + std::vector& Aarray, int64_t Ai, int64_t Aj, int64_t ldda, + std::vector& Barray, int64_t Bi, int64_t Bj, int64_t lddb, + const size_t batch, + blas::Queue &queue ) +{ + using real_t = blas::real_type; + + assert(layout == blas::Layout::ColMajor); + assert(blas::MaxBatchChunk >= batch); + queue.work_ensure_size( 2*batch ); + + scalar_t** dAarray = (scalar_t**)queue.work(); + scalar_t** dBarray = dAarray + batch; + real_t** dSarray = (real_t**)(dBarray + batch); + + blas::device_copy_vector( batch, Sarray.data(), 1, dSarray, 1, queue ); + blas::device_copy_vector( batch, Aarray.data(), 1, dAarray, 1, queue ); + blas::device_copy_vector( batch, Barray.data(), 1, dBarray, 1, queue ); + + dim3 grid_dim; + grid_dim.x = batch; + grid_dim.y = (mb-1)/64 + 1; + grid_dim.z = (nb-1)/64 + 1; + + dim3 block_dim; + block_dim.x = 32; + block_dim.y = 16; + + if (side == blas::Side::Left) { + batch_trsm_scale_copy_kernel<<< grid_dim, block_dim, 0, queue.stream()>>>( + mb, nb, + to_cutype(dSarray), Si, + to_cutype(dAarray), Ai, Aj, ldda, + to_cutype(dBarray), Bi, Bj, lddb); + } else { + batch_trsm_scale_copy_kernel<<< grid_dim, block_dim, 0, queue.stream()>>>( + mb, nb, + to_cutype(dSarray), Si, + to_cutype(dAarray), Ai, Aj, ldda, + to_cutype(dBarray), Bi, Bj, lddb); + } +} + +template +__global__ void __launch_bounds__(512,4) batch_trsm_copy_kernel( + int64_t mb, + int64_t nb, + scalar_t** dAarray, int64_t Ai, int64_t Aj, int64_t ldda, + scalar_t** dBarray, int64_t Bi, int64_t Bj, int64_t lddb) +{ + using real_t = blas::real_type; + + int batch = blockIdx.x; + int blkidx_m = blockIdx.y; + int blksiz_m = gridDim.y; + int blkidx_n = blockIdx.z; + int blksiz_n = gridDim.z; + + int64_t step_m = (mb-1)/blksiz_m + 1; + int64_t m_offset = step_m*blkidx_m; + int64_t mb_local = min(step_m, mb - m_offset); + + int64_t step_n = (nb-1)/blksiz_n + 1; + int64_t n_offset = step_n*blkidx_n; + int64_t nb_local = min(step_n, nb - n_offset); + + + scalar_t* A_local = dAarray[batch] + (Ai+m_offset) + (Aj+n_offset)*ldda; + scalar_t* B_local = dBarray[batch] + (Bi+m_offset) + (Bj+n_offset)*lddb; + + 
tb_copy(mb_local, nb_local, A_local, ldda, B_local, lddb); +} + +template +void batch_copy( + blas::Layout layout, + int64_t mb, + int64_t nb, + std::vector& Aarray, int64_t Ai, int64_t Aj, int64_t ldda, + std::vector& Barray, int64_t Bi, int64_t Bj, int64_t lddb, + const size_t batch, + blas::Queue &queue ) +{ + assert(layout == blas::Layout::ColMajor); + assert(blas::MaxBatchChunk >= batch); + queue.work_ensure_size( 2*batch ); + + scalar_t** dAarray = (scalar_t**)queue.work(); + scalar_t** dBarray = dAarray + batch; + + blas::device_copy_vector( batch, Aarray.data(), 1, dAarray, 1, queue ); + blas::device_copy_vector( batch, Barray.data(), 1, dBarray, 1, queue ); + + dim3 grid_dim; + grid_dim.x = batch; + grid_dim.y = (mb-1)/64 + 1; + grid_dim.z = (nb-1)/64 + 1; + + dim3 block_dim; + block_dim.x = 32; + block_dim.y = 16; + + batch_trsm_copy_kernel<<< grid_dim, block_dim, 0, queue.stream()>>>( + mb, nb, + to_cutype(dAarray), Ai, Aj, ldda, + to_cutype(dBarray), Bi, Bj, lddb); +} + + +template +void batch_trsm_addmod_rec( + blas::Layout layout, + blas::Side side, + blas::Uplo uplo, + int64_t mb, + int64_t nb, + int64_t ib, + scalar_t alpha, + std::vector& Aarray, int64_t Ai, int64_t Aj, int64_t ldda, + std::vector& Uarray, int64_t Ui, int64_t Uj, int64_t lddu, + std::vector& VTarray, int64_t VTi, int64_t VTj, int64_t lddvt, + std::vector*>& Sarray, int64_t Si, + std::vector& Barray, int64_t Bi, int64_t Bj, int64_t lddb, + std::vector& Warray, + const size_t batch, + blas::Queue &queue ) +{ + scalar_t one = 1.0; + scalar_t zero = 0.0; + + // Threshold to switch to cuBLAS for the diagonal GEMM + constexpr int64_t cublas_threshold = 80; + + bool isUpper = uplo == blas::Uplo::Upper; + bool isLeft = side == blas::Side::Left; + + int64_t lddw = (layout == blas::Layout::ColMajor) ? mb : nb; + + blas::Op trans_op = std::is_same>::value ? blas::Op::Trans : blas::Op::ConjTrans; + + if (isUpper && isLeft) { + if (mb <= ib) { + // halt recursion + if constexpr (factorType == BlockFactor::SVD) { + if (ib < cublas_threshold) { + batch_trsm_addmod_diag( + layout, mb, nb, alpha, + Uarray, Ui, Uj, lddu, + VTarray, VTi, VTj, lddvt, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, lddw, + batch, queue); + } else { + batch_scale_copy(layout, side, mb, nb, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, 0, 0, lddw, + batch, queue); + batch_gemm(layout, trans_op, blas::Op::NoTrans, mb, nb, mb, + alpha, VTarray, VTi, VTj, lddvt, + Warray, 0, 0, lddw, + zero, Barray, Bi, Bj, lddb, + batch, queue); + } + } + else if constexpr (factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP) { + auto uplo = factorType == BlockFactor::QLP ? 
blas::Uplo::Lower : blas::Uplo::Upper; + batch_trsm(layout, blas::Side::Left, uplo, blas::Op::NoTrans, blas::Diag::NonUnit, mb, nb, + one, Aarray, Ai, Aj, ldda, + Barray, Bi, Bj, lddb, + batch, queue); + batch_copy(layout, mb, nb, + Barray, Bi, Bj, lddb, + Warray, 0, 0, lddw, + batch, queue); + batch_gemm(layout, trans_op, blas::Op::NoTrans, mb, nb, mb, + alpha, VTarray, VTi, VTj, lddvt, + Warray, 0, 0, lddw, + zero, Barray, Bi, Bj, lddb, + batch, queue); + } + else if constexpr (factorType == BlockFactor::QR) { + batch_trsm(layout, blas::Side::Left, blas::Uplo::Upper, blas::Op::NoTrans, blas::Diag::NonUnit, mb, nb, + one, Aarray, Ai, Aj, ldda, + Barray, Bi, Bj, lddb, + batch, queue); + } + else { + static_assert(factorType == BlockFactor::SVD, "Unsupported block factor"); + } + } + else { + // recurse + int64_t m1 = (((mb-1)/ib+1)/2) * ib; // half the tiles, rounded down + int64_t m2 = mb-m1; + + batch_trsm_addmod_rec(layout, side, uplo, m2, nb, ib, alpha, + Aarray, Ai+m1, Aj+m1, ldda, + Uarray, Ui+m1, Uj+m1, lddu, + VTarray, VTi+m1, VTj+m1, lddvt, + Sarray, Si+m1, + Barray, Bi+m1, Bj, lddb, + Warray, batch, queue); + + batch_gemm(layout, blas::Op::NoTrans, blas::Op::NoTrans, m1, nb, m2, + -one, Aarray, Ai, Aj+m1, ldda, + Barray, Bi+m1, Bj, lddb, + alpha, Barray, Bi, Bj, lddb, + batch, queue); + + batch_trsm_addmod_rec(layout, side, uplo, m1, nb, ib, one, + Aarray, Ai, Aj, ldda, + Uarray, Ui, Uj, lddu, + VTarray, VTi, VTj, lddvt, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, batch, queue); + } + } + else if (isUpper && !isLeft) { + if (nb <= ib) { + // halt recursion + if constexpr (factorType == BlockFactor::SVD) { + if (ib < cublas_threshold) { + batch_trsm_addmod_diag( + layout, mb, nb, alpha, + Uarray, Ui, Uj, lddu, + VTarray, VTi, VTj, lddvt, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, lddw, + batch, queue); + } else { + batch_gemm(layout, blas::Op::NoTrans, trans_op, mb, nb, nb, + alpha, Barray, Bi, Bj, lddb, + VTarray, VTi, VTj, lddvt, + zero, Warray, 0, 0, lddw, + batch, queue); + batch_scale_copy(layout, side, mb, nb, + Sarray, Si, + Warray, 0, 0, lddw, + Barray, Bi, Bj, lddb, + batch, queue); + } + } + else if constexpr (factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP) { + batch_gemm(layout, blas::Op::NoTrans, trans_op, mb, nb, nb, + alpha, Barray, Bi, Bj, lddb, + VTarray, VTi, VTj, lddvt, + zero, Warray, 0, 0, lddw, + batch, queue); + batch_copy(layout, mb, nb, + Warray, 0, 0, lddw, + Barray, Bi, Bj, lddb, + batch, queue); + auto uplo = factorType == BlockFactor::QLP ? 
blas::Uplo::Lower : blas::Uplo::Upper; + batch_trsm(layout, blas::Side::Right, uplo, blas::Op::NoTrans, blas::Diag::NonUnit, mb, nb, + one, Aarray, Ai, Aj, ldda, + Barray, Bi, Bj, lddb, + batch, queue); + } + else if constexpr (factorType == BlockFactor::QR) { + batch_trsm(layout, blas::Side::Right, blas::Uplo::Upper, blas::Op::NoTrans, blas::Diag::NonUnit, mb, nb, + one, Aarray, Ai, Aj, ldda, + Barray, Bi, Bj, lddb, + batch, queue); + } + else { + static_assert(factorType == BlockFactor::SVD, "Unsupported block factor"); + } + } + else { + // recurse + int64_t n1 = (((nb-1)/ib)/2+1) * ib; // half the tiles, rounded up + int64_t n2 = nb-n1; + + batch_trsm_addmod_rec(layout, side, uplo, mb, n1, ib, alpha, + Aarray, Ai, Aj, ldda, + Uarray, Ui, Uj, lddu, + VTarray, VTi, VTj, lddvt, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, batch, queue); + + batch_gemm(layout, blas::Op::NoTrans, blas::Op::NoTrans, mb, n2, n1, + -one, Barray, Bi, Bj, lddb, + Aarray, Ai, Aj+n1, ldda, + alpha, Barray, Bi, Bj+n1, lddb, + batch, queue); + + batch_trsm_addmod_rec(layout, side, uplo, mb, n2, ib, one, + Aarray, Ai+n1, Aj+n1, ldda, + Uarray, Ui+n1, Uj+n1, lddu, + VTarray, VTi+n1, VTj+n1, lddvt, + Sarray, Si+n1, + Barray, Bi, Bj+n1, lddb, + Warray, batch, queue); + } + } + else if (!isUpper && isLeft) { + if (mb <= ib) { + // halt recursion + if constexpr (factorType == BlockFactor::SVD + || factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP + || factorType == BlockFactor::QR) { + if (ib < cublas_threshold) { + batch_trsm_addmod_diag( + layout, /*side, uplo,*/ mb, nb, alpha, + Uarray, Ui, Uj, lddu, + VTarray, VTi, VTj, lddvt, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, lddw, + batch, queue); + } else { + batch_copy(layout, mb, nb, + Barray, Bi, Bj, lddb, + Warray, 0, 0, lddw, + batch, queue); + batch_gemm(layout, trans_op, blas::Op::NoTrans, mb, nb, mb, + alpha, Uarray, Ui, Uj, lddu, + Warray, 0, 0, lddw, + zero, Barray, Bi, Bj, lddb, + batch, queue); + } + } + else { + static_assert(factorType == BlockFactor::SVD, "Unsupported block factor"); + } + } + else { + // recurse + int64_t m1 = (((mb-1)/ib)/2+1) * ib; // half the tiles, rounded up + int64_t m2 = mb-m1; + + batch_trsm_addmod_rec(layout, side, uplo, m1, nb, ib, alpha, + Aarray, Ai, Aj, ldda, + Uarray, Ui, Uj, lddu, + VTarray, VTi, VTj, lddvt, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, batch, queue); + + batch_gemm(layout, blas::Op::NoTrans, blas::Op::NoTrans, m2, nb, m1, + -one, Aarray, Ai+m1, Aj, ldda, + Barray, Bi, Bj, lddb, + alpha, Barray, Bi+m1, Bj, lddb, + batch, queue); + + batch_trsm_addmod_rec(layout, side, uplo, m2, nb, ib, one, + Aarray, Ai+m1, Aj+m1, ldda, + Uarray, Ui+m1, Uj+m1, lddu, + VTarray, VTi+m1, VTj+m1, lddvt, + Sarray, Si+m1, + Barray, Bi+m1, Bj, lddb, + Warray, batch, queue); + } + } + else if (!isUpper && !isLeft) { + if (nb <= ib) { + // halt recursion + if constexpr (factorType == BlockFactor::SVD + || factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP + || factorType == BlockFactor::QR) { + if (ib < cublas_threshold) { + batch_trsm_addmod_diag( + layout, /*side, uplo,*/ mb, nb, alpha, + Uarray, Ui, Uj, lddu, + VTarray, VTi, VTj, lddvt, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, lddw, + batch, queue); + } else { + batch_gemm(layout, blas::Op::NoTrans, trans_op, mb, nb, nb, + alpha, Barray, Bi, Bj, lddb, + Uarray, Ui, Uj, ldda, + zero, Warray, 0, 0, lddw, + batch, queue); + batch_copy(layout, mb, nb, + Warray, 0, 0, lddw, + Barray, Bi, Bj, lddb, + batch, queue); + } + } + else { + 
static_assert(factorType == BlockFactor::SVD, "Unsupported block factor"); + } + } + else { + // recurse + int64_t n1 = (((nb-1)/ib+1)/2) * ib; // half the tiles, rounded down + int64_t n2 = nb-n1; + + batch_trsm_addmod_rec(layout, side, uplo, mb, n2, ib, alpha, + Aarray, Ai+n1, Aj+n1, ldda, + Uarray, Ui+n1, Uj+n1, lddu, + VTarray, VTi+n1, VTj+n1, lddvt, + Sarray, Si+n1, + Barray, Bi, Bj+n1, lddb, + Warray, batch, queue); + + batch_gemm(layout, blas::Op::NoTrans, blas::Op::NoTrans, mb, n1, n2, + -one, Barray, Bi, Bj+n1, lddb, + Aarray, Ai+n1, Aj, ldda, + alpha, Barray, Bi, Bj, lddb, + batch, queue); + + batch_trsm_addmod_rec(layout, side, uplo, mb, n1, ib, one, + Aarray, Ai, Aj, ldda, + Uarray, Ui, Uj, lddu, + VTarray, VTi, VTj, lddvt, + Sarray, Si, + Barray, Bi, Bj, lddb, + Warray, batch, queue); + } + } +} + + + +template +void batch_trsm_addmod( + BlockFactor factorType, + blas::Layout layout, + blas::Side side, + blas::Uplo uplo, + int64_t mb, + int64_t nb, + int64_t ib, + scalar_t alpha, + std::vector Aarray, int64_t ldda, + std::vector Uarray, int64_t lddu, + std::vector VTarray, int64_t lddvt, + std::vector*> Sarray, + std::vector Barray, int64_t lddb, + std::vector Warray, + const size_t batch, + blas::Queue &queue ) +{ + // TODO could assume A, U, S are shared between all thread blocks + + if (factorType == BlockFactor::SVD) { + batch_trsm_addmod_rec(layout, side, uplo, mb, nb, ib, alpha, + Aarray, 0, 0, ldda, + Uarray, 0, 0, lddu, + VTarray, 0, 0, lddvt, + Sarray, 0, + Barray, 0, 0, lddb, + Warray, batch, queue); + } + else if (factorType == BlockFactor::QLP) { + batch_trsm_addmod_rec(layout, side, uplo, mb, nb, ib, alpha, + Aarray, 0, 0, ldda, + Uarray, 0, 0, lddu, + VTarray, 0, 0, lddvt, + Sarray, 0, + Barray, 0, 0, lddb, + Warray, batch, queue); + } + else if (factorType == BlockFactor::QRCP) { + batch_trsm_addmod_rec(layout, side, uplo, mb, nb, ib, alpha, + Aarray, 0, 0, ldda, + Uarray, 0, 0, lddu, + VTarray, 0, 0, lddvt, + Sarray, 0, + Barray, 0, 0, lddb, + Warray, batch, queue); + } + else if (factorType == BlockFactor::QR) { + batch_trsm_addmod_rec(layout, side, uplo, mb, nb, ib, alpha, + Aarray, 0, 0, ldda, + Uarray, 0, 0, lddu, + VTarray, 0, 0, lddvt, + Sarray, 0, + Barray, 0, 0, lddb, + Warray, batch, queue); + } + else { + slate_not_implemented("Only SVD is supported on device"); + } +} + +// Explicit instantiation +template +void batch_trsm_addmod( + BlockFactor factorType, + blas::Layout layout, + blas::Side side, + blas::Uplo uplo, + int64_t mb, + int64_t nb, + int64_t ib, + float alpha, + std::vector Aarray, int64_t ldda, + std::vector Uarray, int64_t lddu, + std::vector VTarray, int64_t lddvt, + std::vector Sarray, + std::vector Barray, int64_t lddb, + std::vector Warray, + const size_t batch, + blas::Queue &queue ); + +template +void batch_trsm_addmod( + BlockFactor factorType, + blas::Layout layout, + blas::Side side, + blas::Uplo uplo, + int64_t mb, + int64_t nb, + int64_t ib, + double alpha, + std::vector Aarray, int64_t ldda, + std::vector Uarray, int64_t lddu, + std::vector VTarray, int64_t lddvt, + std::vector Sarray, + std::vector Barray, int64_t lddb, + std::vector Warray, + const size_t batch, + blas::Queue &queue ); + +template +void batch_trsm_addmod( + BlockFactor factorType, + blas::Layout layout, + blas::Side side, + blas::Uplo uplo, + int64_t mb, + int64_t nb, + int64_t ib, + std::complex alpha, + std::vector*> Aarray, int64_t ldda, + std::vector*> Uarray, int64_t lddu, + std::vector*> VTarray, int64_t lddvt, + std::vector Sarray, + std::vector*> 
Barray, int64_t lddb, + std::vector*> Warray, + const size_t batch, + blas::Queue &queue ); + +template +void batch_trsm_addmod( + BlockFactor factorType, + blas::Layout layout, + blas::Side side, + blas::Uplo uplo, + int64_t mb, + int64_t nb, + int64_t ib, + std::complex alpha, + std::vector*> Aarray, int64_t ldda, + std::vector*> Uarray, int64_t lddu, + std::vector*> VTarray, int64_t lddvt, + std::vector Sarray, + std::vector*> Barray, int64_t lddb, + std::vector*> Warray, + const size_t batch, + blas::Queue &queue ); + +} // namespace device +} // namespace slate diff --git a/src/cuda/device_util.cuh b/src/cuda/device_util.cuh index 4a3e742d9..372c3776a 100644 --- a/src/cuda/device_util.cuh +++ b/src/cuda/device_util.cuh @@ -538,7 +538,29 @@ operator /= (cuDoubleComplex& a, const double s) return a; } -//============================================================================== +// ---------- equality +__host__ __device__ inline bool +operator == (const cuDoubleComplex a, const cuDoubleComplex b) +{ + return ( real(a) == real(b) && + imag(a) == imag(b) ); +} + +__host__ __device__ inline bool +operator == (const cuDoubleComplex a, const double s) +{ + return ( real(a) == s && + imag(a) == 0. ); +} + +__host__ __device__ inline bool +operator == (const double s, const cuDoubleComplex a) +{ + return ( real(a) == s && + imag(a) == 0. ); +} + +// ============================================================================= // complex-float // ---------- negate diff --git a/src/gesv_addmod.cc b/src/gesv_addmod.cc new file mode 100644 index 000000000..0c43cbf88 --- /dev/null +++ b/src/gesv_addmod.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2017-2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/slate.hh" +#include "auxiliary/Debug.hh" +#include "slate/Matrix.hh" +#include "slate/Tile_blas.hh" +#include "slate/TriangularMatrix.hh" +#include "internal/internal.hh" + +namespace slate { + +// TODO docs +template +void gesv_addmod(Matrix& A, AddModFactors& W, + Matrix& B, + Options const& opts) +{ + slate_assert(A.mt() == A.nt()); // square + slate_assert(B.mt() == A.mt()); + + // factorization + getrf_addmod(A, W, opts); + + // solve + getrs_addmod(W, B, opts); + + // todo: return value for errors? +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template +void gesv_addmod( + Matrix& A, AddModFactors& W, + Matrix& B, + Options const& opts); + +template +void gesv_addmod( + Matrix& A, AddModFactors& W, + Matrix& B, + Options const& opts); + +template +void gesv_addmod< std::complex >( + Matrix< std::complex >& A, AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Options const& opts); + +template +void gesv_addmod< std::complex >( + Matrix< std::complex >& A, AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Options const& opts); + +} // namespace slate diff --git a/src/gesv_addmod_ir.cc b/src/gesv_addmod_ir.cc new file mode 100644 index 000000000..8d045bfe4 --- /dev/null +++ b/src/gesv_addmod_ir.cc @@ -0,0 +1,185 @@ +// Copyright (c) 2017-2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. 
See the accompanying LICENSE file. + +#include "slate/slate.hh" +#include "auxiliary/Debug.hh" +#include "slate/Matrix.hh" +#include "slate/Tile_blas.hh" +#include "internal/internal.hh" +#include "internal/internal_util.hh" + +namespace slate { + +template +void gesv_addmod_ir( Matrix& A, AddModFactors& W, + Matrix& B, + Matrix& X, + int& iter, + Options const& opts) +{ + using real_t = blas::real_type; + + Target target = get_option( opts, Option::Target, Target::HostTask ); + + // Most routines prefer column major + const Layout layout = Layout::ColMajor; + const scalar_t one = 1.0; + const real_t eps = std::numeric_limits::epsilon(); + + int64_t itermax = get_option( opts, Option::MaxIterations, 30 ); + double tol = get_option( opts, Option::Tolerance, eps*std::sqrt(A.m()) ); + bool use_fallback = get_option( opts, Option::UseFallbackSolver, true ); + + assert( B.mt() == A.mt() ); + + // workspace + auto R = B.emptyLike(); + auto A_lu = A.template emptyLike(); + + // insert local tiles + R. insertLocalTiles( target ); + A_lu.insertLocalTiles( target ); + + if (target == Target::Devices && itermax != 0) { + #pragma omp parallel + #pragma omp master + #pragma omp taskgroup + { + #pragma omp task slate_omp_default_none \ + shared( A ) firstprivate( layout ) + { + A.tileGetAndHoldAllOnDevices( LayoutConvert( layout ) ); + } + #pragma omp task slate_omp_default_none \ + shared( B ) firstprivate( layout ) + { + B.tileGetAndHoldAllOnDevices( LayoutConvert( layout ) ); + } + #pragma omp task slate_omp_default_none \ + shared( X ) firstprivate( layout ) + { + X.tileGetAndHoldAllOnDevices( LayoutConvert( layout ) ); + } + } + } + + std::vector colnorms_X( X.n() ); + std::vector colnorms_R( R.n() ); + + // stopping criteria + real_t Anorm = norm( Norm::Inf, A, opts ); + real_t cte = Anorm*tol; + bool converged = false; + + // Compute the LU factorization of A_lo. + slate::copy( A, A_lu, opts ); + getrf_addmod( A_lu, W, opts ); + + + // Solve the system A_lu * X = B. + slate::copy( B, X, opts ); + getrs_addmod( W, X, opts ); + + if (itermax == 0) { + return; + } + + // compute r = b - a * x. + slate::copy( B, R, opts ); + gemm( + -one, A, + X, + one, R, opts ); + + // Check whether the nrhs normwise backward error satisfies the + // stopping criterion. If yes, set iter=0 and return. + colNorms( Norm::Max, X, colnorms_X.data(), opts ); + colNorms( Norm::Max, R, colnorms_R.data(), opts ); + + if (internal::iterRefConverged( colnorms_R, colnorms_X, cte )) { + iter = 0; + converged = true; + } + + // iterative refinement + for (int iiter = 0; iiter < itermax && ! converged; iiter++) { + // Solve the system A_lu * X = R. + getrs_addmod( W, R, opts ); + + // Update the current iterate. + add( + one, R, + one, X, opts ); + + // Compute R = B - A * X. + slate::copy( B, R, opts ); + gemm( + -one, A, + X, + one, R, opts ); + + // Check whether nrhs normwise backward error satisfies the + // stopping criterion. If yes, set iter = iiter > 0 and return. + colNorms( Norm::Max, X, colnorms_X.data(), opts ); + colNorms( Norm::Max, R, colnorms_R.data(), opts ); + + if (internal::iterRefConverged( colnorms_R, colnorms_X, cte )) { + iter = iiter+1; + converged = true; + } + } + + if (! converged) { + // If we are at this place of the code, this is because we have performed + // iter = itermax iterations and never satisfied the stopping criterion, + // set up the iter flag accordingly and follow up with double precision + // routine. 
+ iter = -itermax - 1; + + if (use_fallback) { + slate::copy( B, X, opts ); + Pivots pivots; + gesv( A, pivots, X, opts ); + } + } + + // todo: return value for errors? +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template +void gesv_addmod_ir( + Matrix& A, AddModFactors& W, + Matrix& B, + Matrix& X, + int& iter, + Options const& opts); + +template +void gesv_addmod_ir( + Matrix& A, AddModFactors& W, + Matrix& B, + Matrix& X, + int& iter, + Options const& opts); + +template +void gesv_addmod_ir< std::complex >( + Matrix< std::complex >& A, AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Matrix< std::complex >& X, + int& iter, + Options const& opts); + +template +void gesv_addmod_ir< std::complex >( + Matrix< std::complex >& A, AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Matrix< std::complex >& X, + int& iter, + Options const& opts); + +} // namespace slate diff --git a/src/gesv_rbt.cc b/src/gesv_rbt.cc index 70b4d39a3..dd2b83da0 100644 --- a/src/gesv_rbt.cc +++ b/src/gesv_rbt.cc @@ -159,6 +159,7 @@ void gesv_rbt(Matrix& A, std::vector colnorms_X( X.n() ); std::vector colnorms_R( R.n() ); + // stopping criteria real_t cte = Anorm*tol; bool converged = false; diff --git a/src/getrf_addmod.cc b/src/getrf_addmod.cc new file mode 100644 index 000000000..4e9fc5ef9 --- /dev/null +++ b/src/getrf_addmod.cc @@ -0,0 +1,548 @@ +// Copyright (c) 2022-2023, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/addmod.hh" +#include "slate/slate.hh" +#include "auxiliary/Debug.hh" +#include "slate/Matrix.hh" +#include "slate/Tile_blas.hh" +#include "slate/TriangularMatrix.hh" +#include "internal/internal.hh" + +namespace slate { + +namespace internal { + +//------------------------------------------------------------------------------ +/// Distributed parallel LU factorization with additive modifications +/// Generic implementation for any target. +/// Panel and lookahead computed on host using Host OpenMP task. 
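+///
+/// In outline: each diagonal block A(k, k) is factored with the chosen
+/// Option::BlockFactor (SVD by default), and any additive modifications made
+/// to that block are recorded in W.modifications / W.modification_indices
+/// and broadcast to the other ranks. A negative Option::AdditiveTolerance is
+/// interpreted relative to the problem, i.e., mod_tol *= -||A||_F.
+/// If Option::UseWoodbury is set and modifications were made, the routine
+/// also assembles S_VT_Rinv and Linv_U from the recorded modifications,
+/// applies R^{-1} from the right and L^{-1} from the left via trsm_addmod,
+/// and factors the capacitance matrix
+///     C = I - S_VT_Rinv * Linv_U
+/// with getrf, so that getrs_addmod can later undo the modifications through
+/// the Woodbury identity.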
+/// @ingroup gesv_specialization +/// +template +void getrf_addmod(Matrix& A, AddModFactors& W, + Options const& opts) +{ + using BcastList = typename Matrix::BcastList; + using BcastListTag = typename Matrix::BcastListTag; + using real_t = blas::real_type; + + Layout layout = Layout::ColMajor; + + const scalar_t one = 1.0; + const scalar_t zero = 0.0; + + int64_t lookahead = get_option( opts, Option::Lookahead, 1 ); + int64_t ib = get_option( opts, Option::InnerBlocking, 16 ); + real_t mod_tol = get_option( opts, Option::AdditiveTolerance, -1e-8 ); + bool useWoodbury = get_option( opts, Option::UseWoodbury, 1 ); + BlockFactor blockFactorType = get_option( opts, Option::BlockFactor, BlockFactor::SVD ); + + + if (mod_tol < 0) { + // When Target::Device, we don't want norm to move tiles back from device + // So, we set the hold here to prevent norm from removing the device copy at the end + // Then use tileUnsetHoldAllOnDevice to remove the hold but not the device copy + if (target == Target::Devices) { + #pragma omp parallel + #pragma omp master + #pragma omp taskgroup + { + #pragma omp task slate_omp_default_none \ + shared( A ) firstprivate( layout ) + { + A.tileGetAndHoldAllOnDevices( LayoutConvert( layout ) ); + } + } + } + + mod_tol *= -1 * slate::norm(slate::Norm::Fro, A, opts); + + if (target == Target::Devices) { + A.tileUnsetHoldAllOnDevices(); + } + } + + + if (target == Target::Devices) { + // two batch arrays plus one for each lookahead + // batch array size will be set as needed + A.allocateBatchArrays(0, 2 + lookahead); + A.reserveDeviceWorkspace(); + } + + MPI_Comm comm = A.mpiComm(); + const MPI_Datatype mpi_scalar_t = mpi_type::value; + + const int priority_one = 1; + const int priority_zero = 0; + int64_t A_nt = A.nt(); + int64_t A_mt = A.mt(); + int64_t min_mt_nt = std::min(A.mt(), A.nt()); + + W.block_size = ib; + W.factorType = blockFactorType; + W.A = A; + W.U_factors = A.emptyLike(); + W.VT_factors = A.emptyLike(); + W.singular_values.resize(min_mt_nt); + W.modifications.resize(min_mt_nt); + W.modification_indices.resize(min_mt_nt); + + // OpenMP needs pointer types, but vectors are exception safe + std::vector< uint8_t > column_vector(A_nt); + std::vector< uint8_t > diag_vector(A_nt+1); + uint8_t* column = column_vector.data(); + uint8_t* diag = diag_vector.data(); + // Running two listBcastMT's simultaneously can hang due to task ordering + // This dependency avoids that + uint8_t listBcastMT_token; + SLATE_UNUSED(listBcastMT_token); // Only used by OpenMP + + // set min number for omp nested active parallel regions + slate::OmpSetMaxActiveLevels set_active_levels( MinOmpActiveLevels ); + + #pragma omp parallel + #pragma omp master + { + for (int64_t k = 0; k < min_mt_nt; ++k) { + + // panel, high priority + #pragma omp task depend(inout:column[k]) \ + depend(out:diag[k]) \ + priority(priority_one) + { + auto& local_sig_vals = W.singular_values[k]; + auto& local_mod_vals = W.modifications[k]; + auto& local_mod_inds = W.modification_indices[k]; + + int64_t mb = A.tileMb(k); + if (blockFactorType == BlockFactor::SVD) { + local_sig_vals.resize(mb); + } + + if (A.tileIsLocal(k,k)) { + W.U_factors.tileInsert(k, k, HostNum); + W.VT_factors.tileInsert(k, k, HostNum); + } + + // factor A(k, k) + internal::getrf_addmod( + A.sub(k, k, k, k), + W.U_factors.sub(k, k, k, k), + W.VT_factors.sub(k, k, k, k), + std::move(local_sig_vals), + std::move(local_mod_vals), std::move(local_mod_inds), + blockFactorType, mod_tol, ib); + + // broadcast singular values + MPI_Request 
requests[2] = {MPI_REQUEST_NULL, MPI_REQUEST_NULL}; + int64_t num_mods = local_mod_vals.size(); + if (blockFactorType == BlockFactor::SVD) { + slate_mpi_call( + MPI_Ibcast(local_sig_vals.data(), mb, mpi_scalar_t, + A.tileRank(k, k), comm, &requests[0]) ); + } + slate_mpi_call( + MPI_Ibcast(&num_mods, 1, MPI_INT64_T, + A.tileRank(k, k), comm, &requests[1]) ); + + // Update panel + int tag_k = k; + BcastList bcast_list_A, bcast_list_U, bcast_list_VT; + bcast_list_A.push_back({k, k, {A.sub(k+1, A_mt-1, k, k), + A.sub(k, k, k+1, A_nt-1)}}); + bcast_list_U.push_back({k, k, {A.sub(k, k, k+1, A_nt-1)}}); + bcast_list_VT.push_back({k, k, {A.sub(k+1, A_mt-1, k, k)}}); + + A.template listBcast( + bcast_list_A, layout, tag_k ); + W.U_factors.template listBcast( + bcast_list_U, layout, tag_k ); + if (blockFactorType != BlockFactor::QR) { + W.VT_factors.template listBcast( + bcast_list_VT, layout, tag_k ); + } + + // Allow concurrent Bcast's + slate_mpi_call( MPI_Waitall(2, requests, MPI_STATUSES_IGNORE) ); + local_mod_vals.resize(num_mods); + local_mod_inds.resize(num_mods); + } + + #pragma omp task depend(in:diag[k]) depend(in:diag[k+1]) \ + priority(priority_one) + { + auto& local_mod_inds = W.modification_indices[k]; + auto& local_mod_vals = W.modifications[k]; + + int64_t kk_root = A.tileRank(k,k); + int64_t num_mods = local_mod_vals.size(); + if (num_mods != 0) { + + MPI_Request requests[2]; + slate_mpi_call( + MPI_Ibcast(local_mod_vals.data(), num_mods, mpi_scalar_t, + kk_root, comm, &requests[0]) ); + slate_mpi_call( + MPI_Ibcast(local_mod_inds.data(), num_mods, MPI_INT64_T, + kk_root, comm, &requests[1]) ); + + slate_mpi_call( MPI_Waitall(2, requests, MPI_STATUSES_IGNORE) ); + } + } + + #pragma omp task depend(inout:column[k]) \ + depend(in:diag[k]) \ + depend(inout:listBcastMT_token) \ + priority(priority_one) + { + + internal::trsm_addmod( + Side::Right, Uplo::Upper, + scalar_t(1.0), A.sub(k, k, k, k), + W.U_factors.sub(k, k, k, k), + W.VT_factors.sub(k, k, k, k), + std::move(W.singular_values[k]), + A.sub(k+1, A_mt-1, k, k), + blockFactorType, ib, priority_one, layout, 0); + + + BcastListTag bcast_list; + // bcast the tiles of the panel to the trailing matrix + for (int64_t i = k+1; i < A_mt; ++i) { + // send A(i, k) across row A(i, k+1:nt-1) + const int64_t tag = i; + bcast_list.push_back({i, k, {A.sub(i, i, k+1, A_nt-1)}, tag}); + } + A.template listBcastMT( + bcast_list, layout ); + } + // update lookahead column(s), high priority + for (int64_t j = k+1; j < k+1+lookahead && j < A_nt; ++j) { + #pragma omp task depend(in:diag[k]) \ + depend(inout:column[j]) \ + priority(priority_one) + { + int tag_j = j; + + // solve A(k, k) A(k, j) = A(k, j) + internal::trsm_addmod( + Side::Left, Uplo::Lower, + scalar_t(1.0), A.sub(k, k, k, k), + W.U_factors.sub(k, k, k, k), + W.VT_factors.sub(k, k, k, k), + std::move(W.singular_values[k]), + A.sub(k, k, j, j), + blockFactorType, ib, priority_one, layout, j-k+1); + + // send A(k, j) across column A(k+1:mt-1, j) + A.tileBcast(k, j, A.sub(k+1, A_mt-1, j, j), layout, tag_j); + } + + #pragma omp task depend(in:column[k]) \ + depend(inout:column[j]) \ + priority(priority_one) + { + // A(k+1:mt-1, j) -= A(k+1:mt-1, k) * A(k, j) + internal::gemm( + -one, A.sub(k+1, A_mt-1, k, k), + A.sub(k, k, j, j), + one, A.sub(k+1, A_mt-1, j, j), + layout, priority_one, j-k+1); + } + } + // update trailing submatrix, normal priority + if (k+1+lookahead < A_nt) { + #pragma omp task depend(in:diag[k]) \ + depend(inout:column[k+1+lookahead]) \ + depend(inout:column[A_nt-1]) \ + 
depend(inout:listBcastMT_token) + { + // solve A(k, k) A(k, kl+1:nt-1) = A(k, kl+1:nt-1) + internal::trsm_addmod( + Side::Left, Uplo::Lower, + scalar_t(1.0), A.sub(k, k, k, k), + W.U_factors.sub(k, k, k, k), + W.VT_factors.sub(k, k, k, k), + std::move(W.singular_values[k]), + A.sub(k, k, k+1+lookahead, A_nt-1), + blockFactorType, ib, priority_zero, layout, 1); + // send A(k, kl+1:A_nt-1) across A(k+1:mt-1, kl+1:nt-1) + BcastListTag bcast_list; + for (int64_t j = k+1+lookahead; j < A_nt; ++j) { + // send A(k, j) across column A(k+1:mt-1, j) + // tag must be distinct from sending left panel + const int64_t tag = j + A_mt; + bcast_list.push_back({k, j, {A.sub(k+1, A_mt-1, j, j)}, + tag}); + } + A.template listBcastMT( + bcast_list, layout); + } + + #pragma omp task depend(in:column[k]) \ + depend(inout:column[k+1+lookahead]) \ + depend(inout:column[A_nt-1]) + { + // A(k+1:mt-1, kl+1:nt-1) -= A(k+1:mt-1, k) * A(k, kl+1:nt-1) + internal::gemm( + -one, A.sub(k+1, A_mt-1, k, k), + A.sub(k, k, k+1+lookahead, A_nt-1), + one, A.sub(k+1, A_mt-1, k+1+lookahead, A_nt-1), + layout, priority_zero, 1); + } + } + + #pragma omp task depend(inout:column[k]) + { + auto left_panel = A.sub( k, A_mt-1, k, k ); + auto top_panel = A.sub( k, k, k+1, A_nt-1 ); + auto U_k = W.U_factors.sub( k, k, k, k ); + auto VT_k = (blockFactorType != BlockFactor::QR + ? W.VT_factors.sub( k, k, k, k ) + : W.VT_factors.sub( 0, -1, 0, -1 )); + + // Erase remote tiles on all devices including host + left_panel.releaseRemoteWorkspace(); + top_panel.releaseRemoteWorkspace(); + U_k.releaseRemoteWorkspace(); + VT_k.releaseRemoteWorkspace(); + + // Update the origin tiles before their + // workspace copies on devices are erased. + left_panel.tileUpdateAllOrigin(); + top_panel.tileUpdateAllOrigin(); + U_k.tileUpdateAllOrigin(); + VT_k.tileUpdateAllOrigin(); + + // Erase local workspace on devices. 
+ left_panel.releaseLocalWorkspace(); + top_panel.releaseLocalWorkspace(); + U_k.releaseLocalWorkspace(); + VT_k.releaseLocalWorkspace(); + } + } + + #pragma omp taskwait + A.tileUpdateAllOrigin(); + } + + // Build and factor capacitance matrix, if needed + int64_t inner_dim = 0; + if (useWoodbury) { + for (int64_t i = 0; i < A_mt; ++i) { + inner_dim += W.modifications[i].size(); + } + } + W.num_modifications = inner_dim; + if (inner_dim > 0) { + auto A_tileMb = A.tileMbFunc(); + auto A_tileRank = A.tileRankFunc(); + auto A_tileDevice = A.tileDeviceFunc(); + + W.capacitance_matrix = Matrix(inner_dim, inner_dim, + A_tileMb, A_tileMb, + A_tileRank, A_tileDevice, A.mpiComm()); + W.capacitance_matrix.insertLocalTiles(); + W.S_VT_Rinv = Matrix(inner_dim, A.m(), A_tileMb, A_tileMb, + A_tileRank, A_tileDevice, A.mpiComm()); + W.S_VT_Rinv.insertLocalTiles(); + W.Linv_U = Matrix(A.n(), inner_dim, A_tileMb, A_tileMb, + A_tileRank, A_tileDevice, A.mpiComm()); + W.Linv_U.insertLocalTiles(); + + // Build S_VT_Rinv and Linv_U + // First, create sparse S_VT and U matrices + // Then, apply Rinv and Linv + #pragma omp parallel + #pragma omp master + { + internal::set(zero, zero, std::move(W.S_VT_Rinv)); + internal::set(zero, zero, std::move(W.Linv_U)); + #pragma omp taskwait + int64_t tile = 0, tile_offset = 0; + for (int64_t i = 0; i < A_mt; ++i) { + // NB The nonzeros in a row (resp column) are all within the same tile + auto mod_inds = W.modification_indices[i]; + auto mod_vals = W.modifications[i]; + int64_t num_mods = mod_inds.size(); + int64_t mod_offset = 0; + + while (mod_offset < num_mods) { + int64_t chunk = std::min(num_mods - mod_offset, + W.capacitance_matrix.tileMb(tile) - tile_offset); + + if (blockFactorType == BlockFactor::SVD + || blockFactorType == BlockFactor::QLP + || blockFactorType == BlockFactor::QRCP) { + + if (W.S_VT_Rinv.tileIsLocal(tile, i)) { + #pragma omp task firstprivate(tile, i, chunk, mod_offset) \ + firstprivate(mod_inds, mod_vals, num_mods) + { + W.VT_factors.tileRecv(i, i, A.tileRank(i, i), layout, 2*i); + W.VT_factors.tileGetForReading(i, i, LayoutConvert(layout)); + auto tile_VT = W.VT_factors(i, i); + W.S_VT_Rinv.tileGetForWriting(tile, i, LayoutConvert(layout)); + auto tile_CVT = W.S_VT_Rinv(tile, i); + + for (int ii = 0; ii < chunk; ++ii) { + auto s = mod_vals[mod_offset+ii]; + auto ind = mod_inds[mod_offset+ii]; + auto col_offset = (ind / ib)*ib; + + int64_t kb = std::min(A.tileNb(i)-col_offset, ib); + for (int jj = 0; jj < kb; ++jj) { + tile_CVT.at(tile_offset+ii, jj+col_offset) + = s*tile_VT(ind, jj+col_offset); + } + } + W.VT_factors.releaseRemoteWorkspaceTile(i, i); + } + } + else if (W.VT_factors.tileIsLocal(i, i)) { + #pragma omp task firstprivate(tile, i) + { + W.VT_factors.tileSend(i, i, W.S_VT_Rinv.tileRank(tile, i), 2*i); + } + } + } + else if (blockFactorType == BlockFactor::QR) { + if (W.S_VT_Rinv.tileIsLocal(tile, i)) { + #pragma omp task firstprivate(tile, i, chunk, mod_offset) \ + firstprivate(mod_inds, mod_vals, num_mods) + { + W.S_VT_Rinv.tileGetForWriting(tile, i, LayoutConvert(layout)); + auto tile_CVT = W.S_VT_Rinv(tile, i); + + for (int ii = 0; ii < chunk; ++ii) { + auto s = mod_vals[mod_offset+ii]; + auto ind = mod_inds[mod_offset+ii]; + + // tile_VT is (matrix-free) identity block + tile_CVT.at(tile_offset+ii, ind) = s; + } + } + } + } + else { + slate_not_implemented("Unsupported block factor"); + } + + if (blockFactorType == BlockFactor::SVD + || blockFactorType == BlockFactor::QLP + || blockFactorType == BlockFactor::QRCP + || 
blockFactorType == BlockFactor::QR) { + + if (W.Linv_U.tileIsLocal(i, tile)) { + #pragma omp task firstprivate(tile, i, chunk, mod_offset) \ + firstprivate(mod_inds, mod_vals, num_mods) + { + W.U_factors.tileRecv(i, i, W.U_factors.tileRank(i, i), layout, 2*i+1); + W.U_factors.tileGetForReading(i, i, LayoutConvert(layout)); + auto tile_U = W.U_factors(i, i); + W.Linv_U.tileGetForWriting(i, tile, LayoutConvert(layout)); + auto tile_CU = W.Linv_U(i, tile); + + for (int jj = 0; jj < chunk; ++jj) { + auto ind = mod_inds[mod_offset+jj]; + auto row_offset = (ind / ib)*ib; + + int64_t kb = std::min(A.tileMb(i)-row_offset, ib); + for (int ii = 0; ii < kb; ++ii) { + tile_CU.at(ii+row_offset, tile_offset+jj) + = tile_U(ii+row_offset, ind); + } + } + W.U_factors.releaseRemoteWorkspaceTile(i, i); + } + } + else if (W.U_factors.tileIsLocal(i, i)) { + #pragma omp task firstprivate(tile, i) + { + W.U_factors.tileSend(i, i, W.Linv_U.tileRank(i, tile), 2*i+1); + } + } + } + else { + slate_not_implemented("Unsupported block factor"); + } + + tile_offset += chunk; + if (tile_offset >= W.capacitance_matrix.tileMb(tile)) { + tile += 1; + tile_offset = 0; + } + + mod_offset += chunk; + } + } + } + + trsm_addmod(Side::Right, Uplo::Upper, one, W, W.S_VT_Rinv, opts); + trsm_addmod(Side::Left, Uplo::Lower, one, W, W.Linv_U, opts); + + // build & factor capacitance matrix + set(zero, one, W.capacitance_matrix, opts); + gemm(-one, W.S_VT_Rinv, W.Linv_U, + one, W.capacitance_matrix, + opts); + + getrf(W.capacitance_matrix, W.capacitance_pivots, opts); + } + A.clearWorkspace(); +} + +} // namespace internal + +// TODO docs +template +void getrf_addmod(Matrix& A, AddModFactors& W, + Options const& opts) +{ + Target target = get_option( opts, Option::Target, Target::HostTask ); + + switch (target) { + case Target::Host: + case Target::HostTask: + internal::getrf_addmod(A, W, opts); + break; + case Target::HostNest: + internal::getrf_addmod(A, W, opts); + break; + case Target::HostBatch: + internal::getrf_addmod(A, W, opts); + break; + case Target::Devices: + internal::getrf_addmod(A, W, opts); + break; + } + // todo: return value for errors? +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template +void getrf_addmod( + Matrix& A, AddModFactors& W, + Options const& opts); + +template +void getrf_addmod( + Matrix& A, AddModFactors& W, + Options const& opts); + +template +void getrf_addmod< std::complex >( + Matrix< std::complex >& A, AddModFactors< std::complex >& W, + Options const& opts); + +template +void getrf_addmod< std::complex >( + Matrix< std::complex >& A, AddModFactors< std::complex >& W, + Options const& opts); + +} // namespace slate diff --git a/src/getrs_addmod.cc b/src/getrs_addmod.cc new file mode 100644 index 000000000..043f256cf --- /dev/null +++ b/src/getrs_addmod.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2022-2023, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. 
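+
+//------------------------------------------------------------------------------
+// In outline, getrs_addmod() solves A X = B using the modified factorization
+// from getrf_addmod(). Writing M = L R for the factorization that was actually
+// computed, Ubar / Vbar for the modification vectors collected during the
+// factorization (columns of U_factors and rows of VT_factors at the modified
+// indices), and S for the recorded modification amounts, the factors satisfy
+// M = A + Ubar S Vbar^H, so by the Woodbury identity
+//
+//     A^{-1} = M^{-1} + M^{-1} Ubar C^{-1} S Vbar^H M^{-1},
+//     C      = I - S Vbar^H M^{-1} Ubar   (the capacitance matrix),
+//
+// which matches the forward solve, the small capacitance solve, and the
+// backward solve performed below. A minimal call sequence, assuming A, B, and
+// opts are already set up, might look like:
+//
+//     slate::AddModFactors<double> W;
+//     slate::getrf_addmod( A, W, opts );  // factor; modifications recorded in W
+//     slate::getrs_addmod( W, B, opts );  // solve A X = B with Woodbury fix-up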
+ +#include "slate/slate.hh" +#include "auxiliary/Debug.hh" +#include "slate/Matrix.hh" +#include "slate/Tile_blas.hh" +#include "slate/TriangularMatrix.hh" +#include "internal/internal.hh" + +namespace slate { + +// TODO docs +template +void getrs_addmod(AddModFactors& W, + Matrix& B, + Options const& opts) +{ + // Constants + const scalar_t one = 1; + const scalar_t zero = 0; + + Matrix& A = W.A; + + assert(A.mt() == A.nt()); + assert(B.mt() == A.mt()); + + if (A.op() != Op::NoTrans) { + slate_not_implemented("Transposed matrices not yet supported"); + } + + + // Forward substitution, Y = L^{-1} P B. + trsm_addmod(Side::Left, Uplo::Lower, scalar_t(1.0), W, B, opts); + + // Woodbury correction + if (W.num_modifications > 0) { + // Create temporary vector for Woodbury formula + auto A_tileMb = W.A.tileMbFunc(); + auto A_tileRank = W.A.tileRankFunc(); + auto A_tileDevice = W.A.tileDeviceFunc(); + Matrix temp(W.num_modifications, B.n(), + A_tileMb, A_tileMb, + A_tileRank, A_tileDevice, W.A.mpiComm()); + temp.insertLocalTiles(); + + + gemm(one, W.S_VT_Rinv, B, + zero, temp, + opts); + + getrs(W.capacitance_matrix, W.capacitance_pivots, temp, opts); + + gemm(one, W.Linv_U, temp, + one, B, + opts); + } + + // Backward substitution, X = U^{-1} Y. + trsm_addmod(Side::Left, Uplo::Upper, scalar_t(1.0), W, B, opts); + + // todo: return value for errors? +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template +void getrs_addmod( + AddModFactors& W, + Matrix& B, + Options const& opts); + +template +void getrs_addmod( + AddModFactors& W, + Matrix& B, + Options const& opts); + +template +void getrs_addmod< std::complex >( + AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Options const& opts); + +template +void getrs_addmod< std::complex >( + AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Options const& opts); + +} // namespace slate diff --git a/src/internal/Tile_getrf_addmod.hh b/src/internal/Tile_getrf_addmod.hh new file mode 100644 index 000000000..1c47eb6cc --- /dev/null +++ b/src/internal/Tile_getrf_addmod.hh @@ -0,0 +1,288 @@ +// Copyright (c) 2022-2023, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#ifndef SLATE_TILE_GETRF_ADDMOD_HH +#define SLATE_TILE_GETRF_ADDMOD_HH + +#include "internal/internal.hh" +#include "slate/Tile.hh" +#include "slate/types.hh" +#include "slate/enums.hh" + +#include +#include +#include + +namespace slate { +namespace internal { + + +template +scalar_t phase(scalar_t z) { + if constexpr (is_complex::value) { + return z == scalar_t(0.0) ? scalar_t(1.0) : z / std::abs(z); + } + else { + return z >= scalar_t(0.0) ? scalar_t(1.0) : scalar_t(-1.0); + } +} + + +//------------------------------------------------------------------------------ +/// Compute the LU factorization of a tile with additive corrections. 
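+///
+/// The tile is factored in ib-wide panels. Each diagonal block is decomposed
+/// according to the BlockFactor template parameter (SVD, QLP, QRCP, or QR);
+/// keeping the factor type as a compile-time parameter lets each variant be
+/// selected with if constexpr, while callers dispatch from the runtime option.
+/// Any value on the block's "diagonal" (a singular value for SVD, a diagonal
+/// entry of L or R otherwise) whose magnitude is at most mod_tol is pushed out
+/// to mod_tol, keeping its sign or phase, and the change is recorded in
+/// modifications / modified_indices. For example, with mod_tol = 1e-8 a
+/// singular value of 3e-12 becomes 1e-8 and the difference 9.997e-9 is stored.
+/// The panel's off-diagonal blocks are then updated with the modified factors
+/// (roughly A(i,k) := A(i,k) R(k,k)^{-1} and A(k,j) := U(k,k)^H A(k,j)),
+/// followed by the usual trailing-matrix GEMM update.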
+/// +/// @param[in,out] A +/// tile to factor +/// +/// @param[out] U +/// right singular vectors +/// +/// @param[out] singular_values +/// modified singular values +/// +/// @param[out] modifications +/// amount the singular values were modified by +/// +/// @param[out] modified_indices +/// the singular values that correspond to the elements in modifications +/// +/// +/// @param[in] ib +/// blocking factor used in the factorization +/// +/// @ingroup gesv_tile +/// +template +void getrf_addmod(Tile A, + Tile U, + Tile VT, + std::vector< blas::real_type >& singular_values, + std::vector& modifications, + std::vector& modified_indices, + blas::real_type mod_tol, + int64_t ib) +{ + slate_assert(A.layout() == Layout::ColMajor); + + using real_t = blas::real_type; + + const scalar_t one = 1.0; + const scalar_t zero = 0.0; + int64_t mb = A.mb(); + int64_t nb = A.nb(); + int64_t diag_len = std::min(nb, mb); + + scalar_t* A_data = A.data(); + int64_t lda = A.stride(); + scalar_t* U_data = U.data(); + int64_t ldu = U.stride(); + scalar_t* VT_data = VT.data(); + int64_t ldvt = VT.stride(); + + std::vector workspace_vect (ib * std::max(INT64_C(2), std::max(mb, nb))); + scalar_t* workspace = workspace_vect.data(); + + // Used for pivoting + std::vector iworkspace_vect (2*ib); + int64_t* iworkspace = iworkspace_vect.data(); + + using singular_t = std::conditional_t; + + // Loop over ib-wide blocks.. + for (int64_t k = 0; k < diag_len; k += ib) { + int64_t kb = std::min(diag_len-k, ib); + + singular_t* s_vals; + int64_t s_inc; + + scalar_t* Akk = & A_data[k + k*lda]; + scalar_t* Ukk = & U_data[k + k*ldu]; + [[maybe_unused]] scalar_t* VTkk = &VT_data[k + k*ldvt]; + + if constexpr (factorType == BlockFactor::SVD) { + // Compute SVD of diagonal block + //lapack::gesvd(lapack::Job::AllVec, lapack::Job::AllVec, kb, kb, + // Akk, lda, + // &singular_values[k], + // Ukk, ldu, + // VTkk, ldvt); + lapack::gesdd(lapack::Job::AllVec, kb, kb, + Akk, lda, + &singular_values[k], + Ukk, ldu, + VTkk, ldvt); + s_vals = &singular_values[k]; + s_inc = 1; + } + else if constexpr (factorType == BlockFactor::QLP) { + // iworkspace must be zeroed out for geqp3 + for (int64_t i = 0; i < kb; ++i) { + iworkspace[i] = 0; + } + lapack::geqp3(kb, kb, Akk, lda, iworkspace, workspace); + // copy and clear strictly lower-triangular part of Akk + for (int64_t j = 0; j < kb; ++j) { + for (int64_t i = j+1; i < kb; ++i) { + Ukk[i + j*ldu] = Akk[i + j*lda]; + Akk[i + j*lda] = zero; + } + } + // use the non-pivoted LQ since LAPACK doesn't have a pivoted one + lapack::gelqf(kb, kb, Akk, lda, workspace+kb); + + // build left unitary matrix + lapack::ungqr(kb, kb, kb, Ukk, ldu, workspace); + + // build right unitary matrix + lapack::lacpy(lapack::MatrixType::Upper, kb, kb, + Akk, lda, VTkk, ldvt); + lapack::unglq(kb, kb, kb, VTkk, ldvt, workspace+kb); + // geqp3 provides pivots as the final order, requiring an out-of-place permutation + for (int64_t h = 0; h < kb; ++h) { + blas::copy(kb, VTkk+h*ldvt, 1, workspace+(iworkspace[h]-1)*kb, 1); + } + lapack::lacpy(lapack::MatrixType::General, kb, kb, + workspace, kb, VTkk, ldvt); + + s_vals = Akk; + s_inc = lda+1; + } + else if constexpr (factorType == BlockFactor::QRCP) { + // iworkspace must be zeroed out for geqp3 + for (int64_t i = 0; i < kb; ++i) { + iworkspace[i] = 0; + } + lapack::geqp3(kb, kb, Akk, lda, iworkspace, workspace); + lapack::lacpy(lapack::MatrixType::Lower, kb, kb, + Akk, lda, Ukk, ldu); + lapack::ungqr(kb, kb, kb, Ukk, ldu, workspace); + + // TODO just store the pivots 
instead of building a permutation matrix + lapack::laset(lapack::MatrixType::General, kb, kb, zero, zero, VTkk, ldvt); + for (int64_t h = 0; h < kb; ++h) { + VTkk[h + (iworkspace[h]-1)*ldvt] = one; + } + + s_vals = Akk; + s_inc = lda+1; + } + else if constexpr (factorType == BlockFactor::QR) { + lapack::geqrf(kb, kb, Akk, lda, workspace); + lapack::lacpy(lapack::MatrixType::Lower, kb, kb, + Akk, lda, Ukk, ldu); + lapack::ungqr(kb, kb, kb, Ukk, ldu, workspace); + + s_vals = Akk; + s_inc = lda+1; + } + else { + // static_assert must depend on factorType + static_assert(factorType == BlockFactor::SVD, "Not yet implemented"); + } + + // Compute modifications + // TODO consider optimizing when s_vals guaranteed to be ordered or non-negative + for (int64_t i = 0; i < kb; ++i) { + if (std::abs(s_vals[i*s_inc]) <= mod_tol) { + singular_t target = phase(s_vals[i*s_inc])*mod_tol; + singular_t mod = target - s_vals[i*s_inc]; + s_vals[i*s_inc] = target; + + modifications.push_back(mod); + modified_indices.push_back(i); + } + } + + // block-column update: A := A R^-1 + if (k+kb < mb) { + scalar_t* A_panel = &A_data[k+kb + k*lda]; + int64_t ldwork = mb-k-kb; + + if constexpr (factorType == BlockFactor::SVD) { + blas::gemm(Layout::ColMajor, + Op::NoTrans, Op::ConjTrans, + mb-k-kb, kb, kb, + one, A_panel, lda, + VTkk, ldvt, + zero, workspace, ldwork); + + for (int64_t j = 0; j < kb; ++j) { + scalar_t Sj = singular_values[k+j]; + for (int64_t i = 0; i < mb-k-kb; ++i) { + A_panel[i + j*lda] = workspace[i + j*ldwork] / Sj; + } + } + } + else if constexpr (factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP) { + // TODO redo this as a simple permutation for QRCP + blas::gemm(Layout::ColMajor, + Op::NoTrans, Op::ConjTrans, + mb-k-kb, kb, kb, + one, A_panel, lda, + VTkk, ldvt, + zero, workspace, ldwork); + lapack::lacpy(lapack::MatrixType::General, mb-k-kb, kb, workspace, ldwork, A_panel, lda); + + auto uplo = factorType == BlockFactor::QLP ? 
Uplo::Lower : Uplo::Upper; + blas::trsm(Layout::ColMajor, Side::Right, uplo, + Op::NoTrans, Diag::NonUnit, mb-k-kb, kb, + one, Akk, lda, A_panel, lda); + } + else if constexpr (factorType == BlockFactor::QR) { + blas::trsm(Layout::ColMajor, Side::Right, Uplo::Upper, + Op::NoTrans, Diag::NonUnit, mb-k-kb, kb, + one, Akk, lda, A_panel, lda); + } + else { + // static_assert must depend on factorType + static_assert(factorType == BlockFactor::SVD, "Not yet implemented"); + } + } + + // block-row update: A := L^-1 A + if (k+kb < nb) { + scalar_t* A_panel = &A_data[k + (k+kb)*lda]; + + if constexpr (factorType == BlockFactor::SVD + || factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP + || factorType == BlockFactor::QR) { + for (int64_t j = 0; j < nb-k-kb; ++j) { + for (int64_t i = 0; i < kb; ++i) { + workspace[i + j*kb] = A_panel[i + j*lda]; + } + } + + blas::gemm(Layout::ColMajor, + Op::ConjTrans, Op::NoTrans, + kb, nb-k-kb, kb, + one, Ukk, ldu, + workspace, kb, + zero, A_panel, lda); + } + else { + // static_assert must depend on factorType + static_assert(factorType == BlockFactor::SVD, "Not yet implemented"); + } + } + + // trailing matrix update + if (k+kb < mb && k+kb < nb) { + blas::gemm(Layout::ColMajor, + Op::NoTrans, Op::NoTrans, + mb-k-kb, nb-k-kb, kb, + -one, &A_data[k+kb + (k )*lda], lda, + &A_data[k + (k+kb)*lda], lda, + one, &A_data[k+kb + (k+kb)*lda], lda); + } + } +} + +} // namespace internal +} // namespace slate + +#endif // SLATE_TILE_GETRF_ADDMOD_HH diff --git a/src/internal/Tile_trsm_addmod.hh b/src/internal/Tile_trsm_addmod.hh new file mode 100644 index 000000000..94b83621b --- /dev/null +++ b/src/internal/Tile_trsm_addmod.hh @@ -0,0 +1,472 @@ +// Copyright (c) 2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. 
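+
+//------------------------------------------------------------------------------
+// The recursive kernels in this header perform triangular-like solves in which
+// the diagonal blocks come from the addmod block factorization rather than
+// being unit triangular. For the lower factor the diagonal block is the
+// orthogonal/unitary U, so the base case multiplies by U^H. For the upper
+// factor it is Sigma V^H in the SVD case (divide by the modified singular
+// values, multiply by V), a plain triangular block for QR (ordinary trsm), and
+// a triangular block plus the stored VT / permutation factor for QLP and QRCP.
+// The recursion splits the dimension at a multiple of ib; a rough sketch of the
+// lower-left variant:
+//
+//     solve_lower(m, B):                        // computes B := L^{-1} B
+//         if (m <= ib) { B := U^H B; return; }  // diagonal block: orthogonal U
+//         m1 = half of m, rounded to a multiple of ib;  m2 = m - m1;
+//         solve_lower(m1, B(0:m1, :));
+//         B(m1:m, :) -= L(m1:m, 0:m1) * B(0:m1, :);
+//         solve_lower(m2, B(m1:m, :));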
+ +#ifndef SLATE_TILE_TRSM_ADDMOD_HH +#define SLATE_TILE_TRSM_ADDMOD_HH + +#include "internal/internal.hh" +#include "slate/Tile.hh" +#include "slate/types.hh" +#include "slate/enums.hh" + +#include +#include + +namespace slate { +namespace tile { + +template +scalar_t* matrix_offset(Layout layout, scalar_t* A, int64_t lda, int64_t i, int64_t j) +{ + if (layout == Layout::ColMajor) { + return A + i + j*lda; + } + else { + return A + i*lda + j; + } +} + +template +void diag_solve(Layout layout, Side side, int64_t mb, int64_t nb, + blas::real_type* S, + scalar_t* B, int64_t ldb, + scalar_t* C, int64_t ldc) +{ + if (side == Side::Left) { + if (layout == Layout::ColMajor) { + assert(ldb >= mb); + assert(ldc >= mb); + for (int64_t j = 0; j < nb; j++) { + for (int64_t i = 0; i < mb; i++) { + C[i + j*ldc] = B[i + j*ldb] / S[i]; + } + } + } + else { + assert(ldb >= nb); + assert(ldc >= nb); + for (int64_t i = 0; i < mb; i++) { + for (int64_t j = 0; j < nb; j++) { + C[i*ldc + j] = B[i*ldb + j] / S[i]; + } + } + } + } + else { + if (layout == Layout::ColMajor) { + assert(ldb >= mb); + assert(ldc >= mb); + for (int64_t j = 0; j < nb; j++) { + for (int64_t i = 0; i < mb; i++) { + C[i + j*ldc] = B[i + j*ldb] / S[j]; + } + } + } + else { + assert(ldb >= nb); + assert(ldc >= nb); + for (int64_t i = 0; i < mb; i++) { + for (int64_t j = 0; j < nb; j++) { + C[i*ldc + j] = B[i*ldb + j] / S[j]; + } + } + } + } +} + +template +void lacpy_layout(Layout layout, int64_t mb, int64_t nb, + scalar_t A, int64_t lda, scalar_t B, int64_t ldb) { + if (layout == Layout::ColMajor) { + lapack::lacpy(lapack::MatrixType::General, mb, nb, + A, lda, B, ldb); + } + else { + lapack::lacpy(lapack::MatrixType::General, nb, mb, + A, lda, B, ldb); + } +} + +template +void trsm_addmod_recur_lower_left(int64_t ib, Layout layout, + int64_t mb, int64_t nb, + scalar_t alpha, + scalar_t* A, int64_t lda, + scalar_t* U, int64_t ldu, + scalar_t* B, int64_t ldb, + scalar_t* work, int64_t ldwork) + +{ + const scalar_t one = 1.0; + [[maybe_unused]] + const scalar_t zero = 0.0; + + if (mb <= ib) { + // halt recursion + if constexpr (factorType == BlockFactor::SVD + || factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP + || factorType == BlockFactor::QR) { + blas::gemm(layout, + Op::ConjTrans, Op::NoTrans, + mb, nb, mb, + alpha, U, ldu, + B, ldb, + zero, work, ldwork); + + lacpy_layout(layout, mb, nb, + work, ldwork, B, ldb); + } + else { + slate_not_implemented( "Block factorization not implemented" ); + } + } + else { + int64_t m1 = (((mb-1)/ib)/2+1) * ib; // half the tiles, rounded up + int64_t m2 = mb-m1; + + trsm_addmod_recur_lower_left(ib, layout, m1, nb, alpha, + A, lda, + U, ldu, + B, ldb, + work, ldwork); + + blas::gemm(layout, Op::NoTrans, Op::NoTrans, m2, nb, m1, + -one, matrix_offset(layout, A, lda, m1, 0), lda, + matrix_offset(layout, B, ldb, 0, 0), ldb, + alpha, matrix_offset(layout, B, ldb, m1, 0), ldb); + + trsm_addmod_recur_lower_left(ib, layout, m2, nb, one, + matrix_offset(layout, A, lda, m1, m1), lda, + matrix_offset(layout, U, ldu, m1, m1), ldu, + matrix_offset(layout, B, ldb, m1, 0), ldb, + work, ldwork); + } +} + +template +void trsm_addmod_recur_upper_left(int64_t ib, Layout layout, + int64_t mb, int64_t nb, + scalar_t alpha, + scalar_t* A, int64_t lda, + scalar_t* VT, int64_t ldvt, + blas::real_type* S, + scalar_t* B, int64_t ldb, + scalar_t* work, int64_t ldwork) + +{ + const scalar_t one = 1.0; + [[maybe_unused]] + const scalar_t zero = 0.0; + + if (mb <= ib) { + // halt recursion + if constexpr 
(factorType == BlockFactor::SVD) { + diag_solve(layout, Side::Left, mb, nb, + S, B, ldb, work, ldwork); + + blas::gemm(layout, + Op::ConjTrans, Op::NoTrans, + mb, nb, mb, + alpha, VT, ldvt, + work, ldwork, + zero, B, ldb); + } + else if constexpr (factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP) { + auto uplo = factorType == BlockFactor::QLP ? Uplo::Lower : Uplo::Upper; + blas::trsm(layout, Side::Left, uplo, + Op::NoTrans, Diag::NonUnit, mb, nb, + alpha, A, lda, B, ldb); + + blas::gemm(layout, + Op::ConjTrans, Op::NoTrans, + mb, nb, mb, + alpha, VT, ldvt, + B, ldb, + zero, work, ldwork); + lacpy_layout(layout, mb, nb, + work, ldwork, B, ldb); + } + else if constexpr (factorType == BlockFactor::QR) { + blas::trsm(layout, Side::Left, Uplo::Upper, + Op::NoTrans, Diag::NonUnit, mb, nb, + alpha, A, lda, B, ldb); + } + else { + slate_not_implemented( "Block factorization not implemented" ); + } + } + else { + int64_t m1 = (((mb-1)/ib+1)/2) * ib; // half the tiles, rounded down + int64_t m2 = mb-m1; + + trsm_addmod_recur_upper_left(ib, layout, m2, nb, alpha, + matrix_offset(layout, A, lda, m1, m1), lda, + matrix_offset(layout, VT, ldvt, m1, m1), ldvt, + S + m1, + matrix_offset(layout, B, ldb, m1, 0), ldb, + work, ldwork); + + blas::gemm(layout, Op::NoTrans, Op::NoTrans, m1, nb, m2, + -one, matrix_offset(layout, A, lda, 0, m1), lda, + matrix_offset(layout, B, ldb, m1, 0), ldb, + alpha, matrix_offset(layout, B, ldb, 0, 0), ldb); + + trsm_addmod_recur_upper_left(ib, layout, m1, nb, one, + A, lda, + VT, ldvt, + S, + B, ldb, + work, ldwork); + } +} + +template +void trsm_addmod_recur_lower_right(int64_t ib, Layout layout, + int64_t mb, int64_t nb, + scalar_t alpha, + scalar_t* A, int64_t lda, + scalar_t* U, int64_t ldu, + scalar_t* B, int64_t ldb, + scalar_t* work, int64_t ldwork) + +{ + const scalar_t one = 1.0; + [[maybe_unused]] + const scalar_t zero = 0.0; + + if (nb <= ib) { + // halt recursion + if constexpr (factorType == BlockFactor::SVD + || factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP + || factorType == BlockFactor::QR) { + blas::gemm(layout, + Op::NoTrans, Op::ConjTrans, + mb, nb, nb, + alpha, B, ldb, + U, ldu, + zero, work, ldwork); + + lacpy_layout(layout, mb, nb, + work, ldwork, B, ldb); + } + else { + slate_not_implemented( "Block factorization not implemented" ); + } + } + else { + // recurse + int64_t n1 = (((nb-1)/ib+1)/2) * ib; // half the tiles, rounded down + int64_t n2 = nb-n1; + + trsm_addmod_recur_lower_right(ib, layout, mb, n2, alpha, + matrix_offset(layout, A, lda, n1, n1), lda, + matrix_offset(layout, U, ldu, n1, n1), ldu, + matrix_offset(layout, B, ldu, 0, n1), ldb, + work, ldwork); + + blas::gemm(layout, Op::NoTrans, Op::NoTrans, mb, n1, n2, + -one, matrix_offset(layout, B, ldb, 0, n1), ldb, + matrix_offset(layout, A, lda, n1, 0), lda, + alpha, matrix_offset(layout, B, ldb, 0, 0), ldb); + + trsm_addmod_recur_lower_right(ib, layout, mb, n1, one, + A, lda, + U, ldu, + B, ldb, + work, ldwork); + } +} + +template +void trsm_addmod_recur_upper_right(int64_t ib, Layout layout, + int64_t mb, int64_t nb, + scalar_t alpha, + scalar_t* A, int64_t lda, + scalar_t* VT, int64_t ldvt, + blas::real_type* S, + scalar_t* B, int64_t ldb, + scalar_t* work, int64_t ldwork) + +{ + const scalar_t one = 1.0; + [[maybe_unused]] + const scalar_t zero = 0.0; + + if (nb <= ib) { + // halt recursion + if constexpr (factorType == BlockFactor::SVD) { + blas::gemm(layout, + Op::NoTrans, Op::ConjTrans, + mb, nb, nb, + alpha, B, ldb, + VT, ldvt, + zero, work, 
ldwork); + + diag_solve(layout, Side::Right, mb, nb, + S, work, ldwork, B, ldb); + } + else if constexpr (factorType == BlockFactor::QLP + || factorType == BlockFactor::QRCP) { + blas::gemm(layout, + Op::NoTrans, Op::ConjTrans, + mb, nb, nb, + alpha, B, ldb, + VT, ldvt, + zero, work, ldwork); + lacpy_layout(layout, mb, nb, + work, ldwork, B, ldb); + + auto uplo = factorType == BlockFactor::QLP ? Uplo::Lower : Uplo::Upper; + blas::trsm(Layout::ColMajor, Side::Right, uplo, + Op::NoTrans, Diag::NonUnit, mb, nb, + alpha, A, lda, B, ldb); + } + else if constexpr (factorType == BlockFactor::QR) { + blas::trsm(Layout::ColMajor, Side::Right, Uplo::Upper, + Op::NoTrans, Diag::NonUnit, mb, nb, + alpha, A, lda, B, ldb); + } + else { + slate_not_implemented( "Block factorization not implemented" ); + } + } + else { + // recurse + int64_t n1 = (((nb-1)/ib)/2+1) * ib; // half the tiles, rounded up + int64_t n2 = nb-n1; + + trsm_addmod_recur_upper_right(ib, layout, mb, n1, alpha, + A, lda, + VT, ldvt, + S, + B, ldb, + work, ldwork); + + blas::gemm(layout, Op::NoTrans, Op::NoTrans, mb, n2, n1, + -one, matrix_offset(layout, B, ldb, 0, 0), ldb, + matrix_offset(layout, A, lda, 0, n1), lda, + alpha, matrix_offset(layout, B, ldb, 0, n1), ldb); + + trsm_addmod_recur_upper_right(ib, layout, mb, n2, one, + matrix_offset(layout, A, lda, n1, n1), lda, + matrix_offset(layout, VT, ldvt, n1, n1), ldvt, + S + n1, + matrix_offset(layout, B, ldb, 0, n1), ldb, + work, ldwork); + } +} + + +template +void trsm_addmod_helper(int64_t ib, Side side, Uplo uplo, scalar_t alpha, + Tile A, + Tile U, + Tile VT, + std::vector>& S, + Tile B) +{ + slate_assert(U.uploPhysical() == Uplo::General); + slate_assert(VT.uploPhysical() == Uplo::General); + slate_assert(B.uploPhysical() == Uplo::General); + slate_assert(U.layout() == B.layout()); + slate_assert(VT.layout() == B.layout()); + + slate_assert(A.mb() == A.nb()); + slate_assert(U.mb() == U.nb()); + slate_assert(A.mb() == U.mb()); + + blas::real_type* S_data = nullptr; + if (factorType == BlockFactor::SVD) { + slate_assert(A.mb() == int64_t(S.size())); + S_data = S.data(); + } + + int64_t mb = B.mb(), nb = B.nb(); + + // TODO this allocation can be made smaller + int64_t work_stride = B.layout() == Layout::ColMajor ? 
mb : nb; + std::vector work_vect(mb*nb); + scalar_t* work = work_vect.data(); + + if (uplo == Uplo::Lower) { + if (side == Side::Right) { + // Lower, Right + slate_assert(U.mb() == nb); + + trsm_addmod_recur_lower_right(ib, B.layout(), mb, nb, alpha, + A.data(), A.stride(), + U.data(), U.stride(), + B.data(), B.stride(), + work, work_stride); + } + else { + // Lower, Left + slate_assert(U.mb() == mb); + + trsm_addmod_recur_lower_left(ib, B.layout(), mb, nb, alpha, + A.data(), A.stride(), + U.data(), U.stride(), + B.data(), B.stride(), + work, work_stride); + } + } + else { + if (side == Side::Right) { + // Upper, Right + slate_assert(A.mb() == nb); + + trsm_addmod_recur_upper_right(ib, B.layout(), mb, nb, alpha, + A.data(), A.stride(), + VT.data(),VT.stride(), + S_data, + B.data(), B.stride(), + work, work_stride); + } + else { + // Upper, Left + slate_assert(A.mb() == mb); + + trsm_addmod_recur_upper_left(ib, B.layout(), mb, nb, alpha, + A.data(), A.stride(), + VT.data(),VT.stride(), + S_data, + B.data(), B.stride(), + work, work_stride); + } + } +} + +template +void trsm_addmod(BlockFactor factorType, int64_t ib, Side side, Uplo uplo, scalar_t alpha, + Tile A, + Tile U, + Tile VT, + std::vector>& S, + Tile B) +{ + if (factorType == BlockFactor::SVD) { + trsm_addmod_helper(ib, side, uplo, alpha, A, U, VT, S, B); + } + else if (factorType == BlockFactor::QLP) { + trsm_addmod_helper(ib, side, uplo, alpha, A, U, VT, S, B); + } + else if (factorType == BlockFactor::QRCP) { + trsm_addmod_helper(ib, side, uplo, alpha, A, U, VT, S, B); + } + else if (factorType == BlockFactor::QR) { + trsm_addmod_helper(ib, side, uplo, alpha, A, U, VT, S, B); + } + else { + slate_not_implemented( "Block factorization not implemented" ); + } +} + + +} // namespace internal +} // namespace slate +#endif // SLATE_TILE_TRSM_ADDMOD_HH diff --git a/src/internal/internal.hh b/src/internal/internal.hh index 20963ff3b..c8e3f16cc 100644 --- a/src/internal/internal.hh +++ b/src/internal/internal.hh @@ -349,6 +349,30 @@ void trsmA(Side side, int priority=0, Layout layout=Layout::ColMajor, int64_t queue_index=0 ); +//----------------------------------------- +// trsm_addmod() +template +void trsm_addmod(Side side, Uplo uplo, scalar_t alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector>&& S, + Matrix&& B, + BlockFactor blockFactorType, int64_t ib, + int priority=0, Layout layout=Layout::ColMajor, int64_t queue_index=0 ); + +//----------------------------------------- +// trsmA_addmod() +template +void trsmA_addmod(Side side, Uplo uplo, scalar_t alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector>&& S, + Matrix&& B, + BlockFactor blockFactorType, int64_t ib, + int priority=0, Layout layout=Layout::ColMajor, int64_t queue_index=0); + //----------------------------------------- // trtri() template @@ -513,6 +537,19 @@ void getrf_tntpiv_panel( std::vector& pivot, int max_panel_threads, int priority, int64_t* info ); +//----------------------------------------- +// getrf_addmod() +template +void getrf_addmod(Matrix< scalar_t >&& A, + Matrix< scalar_t >&& U, + Matrix< scalar_t >&& VT, + std::vector< blas::real_type >&& singular_values, + std::vector&& modifications, + std::vector&& modified_indices, + BlockFactor blockFactorType, + blas::real_type mod_tol, + int64_t ib); + //----------------------------------------- // geqrf() template diff --git a/src/internal/internal_getrf_addmod.cc b/src/internal/internal_getrf_addmod.cc new file mode 100644 index 000000000..70a4d6b16 --- /dev/null +++ 
b/src/internal/internal_getrf_addmod.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2017-2020, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/Matrix.hh" +#include "slate/types.hh" +#include "internal/Tile_getrf_addmod.hh" +#include "internal/internal.hh" + +namespace slate { +namespace internal { + +//------------------------------------------------------------------------------ +/// LU factorization of single tile with additive corrections. +/// Dispatches to target implementations. +/// @ingroup gesv_internal +/// +template +void getrf_addmod(Matrix< scalar_t >&& A, + Matrix< scalar_t >&& U, + Matrix< scalar_t >&& VT, + std::vector< blas::real_type >&& singular_values, + std::vector&& modifications, + std::vector&& modified_indices, + BlockFactor blockFactorType, + blas::real_type mod_tol, + int64_t ib) +{ + getrf_addmod(internal::TargetType(), + A, U, VT, singular_values, modifications, modified_indices, + blockFactorType, mod_tol, ib); +} + +//------------------------------------------------------------------------------ +/// LU factorization of single tile with additive corrections, host implementation. +/// @ingroup gesv_internal +/// +template +void getrf_addmod(internal::TargetType, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector< blas::real_type >& singular_values, + std::vector& modifications, + std::vector& modified_indices, + BlockFactor blockFactorType, + blas::real_type mod_tol, + int64_t ib) +{ + assert(A.mt() == 1); + assert(A.nt() == 1); + + if (A.tileIsLocal(0, 0)) { + A.tileGetForWriting(0, 0, LayoutConvert::ColMajor); + U.tileGetForWriting(0, 0, LayoutConvert::ColMajor); + VT.tileGetForWriting(0, 0, LayoutConvert::ColMajor); + + if (blockFactorType == BlockFactor::SVD) { + getrf_addmod(A(0, 0), U(0, 0), VT(0, 0), singular_values, + modifications, modified_indices, mod_tol, ib); + } + else if (blockFactorType == BlockFactor::QLP) { + getrf_addmod(A(0, 0), U(0, 0), VT(0, 0), singular_values, + modifications, modified_indices, mod_tol, ib); + } + else if (blockFactorType == BlockFactor::QRCP) { + getrf_addmod(A(0, 0), U(0, 0), VT(0, 0), singular_values, + modifications, modified_indices, mod_tol, ib); + } + else if (blockFactorType == BlockFactor::QR) { + getrf_addmod(A(0, 0), U(0, 0), VT(0, 0), singular_values, + modifications, modified_indices, mod_tol, ib); + } + else { + slate_not_implemented("Unsupported BlockFactor"); + } + } +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. 
+// ---------------------------------------- +template +void getrf_addmod( + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& singular_values, + std::vector&& modifications, + std::vector&& modified_indices, + BlockFactor blockFactorType, + float mod_tol, + int64_t ib); + +// ---------------------------------------- +template +void getrf_addmod( + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& singular_values, + std::vector&& modifications, + std::vector&& modified_indices, + BlockFactor blockFactorType, + double mod_tol, + int64_t ib); + +// ---------------------------------------- +template +void getrf_addmod< Target::HostTask, std::complex >( + Matrix< std::complex >&& A, + Matrix< std::complex >&& U, + Matrix< std::complex >&& VT, + std::vector&& singular_values, + std::vector< std::complex >&& modifications, + std::vector&& modified_indices, + BlockFactor blockFactorType, + float mod_tol, + int64_t ib); + +// ---------------------------------------- +template +void getrf_addmod< Target::HostTask, std::complex >( + Matrix< std::complex >&& A, + Matrix< std::complex >&& U, + Matrix< std::complex >&& VT, + std::vector&& singular_values, + std::vector< std::complex >&& modifications, + std::vector&& modified_indices, + BlockFactor blockFactorType, + double mod_tol, + int64_t ib); + +} // namespace internal +} // namespace slate diff --git a/src/internal/internal_trsmA_addmod.cc b/src/internal/internal_trsmA_addmod.cc new file mode 100644 index 000000000..59c162479 --- /dev/null +++ b/src/internal/internal_trsmA_addmod.cc @@ -0,0 +1,346 @@ +// Copyright (c) 2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/Matrix.hh" +#include "slate/types.hh" +#include "internal/internal.hh" +#include "internal/Tile_trsm_addmod.hh" +#include "slate/Tile_blas.hh" + +namespace slate { +namespace internal { + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// Dispatches to target implementations. +/// @ingroup trsm_internal +/// +template +void trsmA_addmod(Side side, Uplo uplo, scalar_t alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector>&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index) +{ + trsmA_addmod(internal::TargetType(), + side, uplo, alpha, A, U, VT, S, B, + blockFactorType, ib, priority, layout, queue_index); +} + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// Host OpenMP task implementation. 
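+/// The rank owning the factored diagonal tile A(0, 0) reads it, plus U for the
+/// lower solve or VT for the upper solve, and launches one OpenMP task per tile
+/// of B in the block row or column, each calling the tile-level trsm_addmod.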
+/// @ingroup trsm_internal +/// +template +void trsmA_addmod(internal::TargetType, + Side side, Uplo uplo, scalar_t alpha, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector>& S, + Matrix& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index) +{ + // CPU assumes column major + // todo: relax this assumption, by allowing Tile_blas.hh::trsm() + // to take layout param + // todo: optimize for the number of layout conversions, + // by watching 'layout' and 'B(i, j).layout()' + assert(layout == Layout::ColMajor); + assert(U.mt() == 1); + assert(VT.mt() == 1); + assert(A.mt() == 1); + assert(A.tileIsLocal(0,0) == U.tileIsLocal(0,0)); + assert(A.tileIsLocal(0,0) == VT.tileIsLocal(0,0)); + + if (A.tileIsLocal(0, 0)) { + A .tileGetForReading(0, 0, LayoutConvert(layout)); + if (uplo == Uplo::Lower) { + U.tileGetForReading(0, 0, LayoutConvert(layout)); + } + else { + VT.tileGetForReading(0, 0, LayoutConvert(layout)); + } + } + + #pragma omp taskgroup + if (side == Side::Right) { + assert(B.nt() == 1); + if (A.tileIsLocal(0, 0)) { + for (int64_t i = 0; i < B.mt(); ++i) { + #pragma omp task slate_omp_default_none \ + shared( A, U, VT, S, B ) \ + firstprivate(i, layout, side, uplo, ib) priority(priority) + { + B.tileGetForWriting(i, 0, LayoutConvert(layout)); + tile::trsm_addmod(blockFactorType, ib, side, uplo, alpha, + A(0, 0), U(0, 0), VT(0, 0), S, B(i, 0)); + } + } + } + } + else { + assert(B.mt() == 1); + if (A.tileIsLocal(0, 0)) { + for (int64_t j = 0; j < B.nt(); ++j) { + #pragma omp task slate_omp_default_none \ + shared( A, U, VT, S, B ) \ + firstprivate(j, layout, side, uplo, ib) priority(priority) + { + B.tileGetForWriting(0, j, LayoutConvert(layout)); + tile::trsm_addmod(blockFactorType, ib, side, uplo, alpha, + A(0, 0), U(0, 0), VT(0, 0), S, B(0, j)); + } + } + } + } +} + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// Host nested OpenMP implementation. +/// @ingroup trsm_internal +/// +template +void trsmA_addmod(internal::TargetType, + Side side, Uplo uplo, scalar_t alpha, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector>& S, + Matrix& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index) +{ + slate_not_implemented("Target::HostNest isn't yet supported."); +} + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// Host batched implementation. +/// @ingroup trsm_internal +/// +template +void trsmA_addmod(internal::TargetType, + Side side, Uplo uplo, scalar_t alpha, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector>& S, + Matrix& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index) +{ + slate_not_implemented("Target::HostBatch isn't yet supported."); +} + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// GPU device batched cuBLAS implementation. 
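+/// Currently a stub that raises slate_not_implemented(), so the trsmA addmod
+/// path runs on the host only for now.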
+/// @ingroup trsm_internal +/// +template +void trsmA_addmod(internal::TargetType, + Side side, Uplo uplo, scalar_t alpha, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector>& S, + Matrix& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index) +{ + slate_not_implemented("Target::Devices isn't yet supported."); +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +// ---------------------------------------- +template +void trsmA_addmod( + Side side, Uplo uplo, float alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod( + Side side, Uplo uplo, float alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod( + Side side, Uplo uplo, float alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod( + Side side, Uplo uplo, float alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +// ---------------------------------------- +template +void trsmA_addmod( + Side side, Uplo uplo, double alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod( + Side side, Uplo uplo, double alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod( + Side side, Uplo uplo, double alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod( + Side side, Uplo uplo, double alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +// ---------------------------------------- +template +void trsmA_addmod< Target::HostTask, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod< Target::HostNest, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod< Target::HostBatch, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod< Target::Devices, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, 
+ std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +// ---------------------------------------- +template +void trsmA_addmod< Target::HostTask, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod< Target::HostNest, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod< Target::HostBatch, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +template +void trsmA_addmod< Target::Devices, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index); + +} // namespace internal +} // namespace slate diff --git a/src/internal/internal_trsm_addmod.cc b/src/internal/internal_trsm_addmod.cc new file mode 100644 index 000000000..d50a2e13f --- /dev/null +++ b/src/internal/internal_trsm_addmod.cc @@ -0,0 +1,614 @@ +// Copyright (c) 2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/internal/device.hh" +#include "slate/Matrix.hh" +#include "slate/types.hh" +#include "internal/DevVector.hh" +#include "internal/internal.hh" +#include "internal/Tile_trsm_addmod.hh" +#include "slate/Tile_blas.hh" + +namespace slate { +namespace internal { + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// Dispatches to target implementations. +/// @ingroup trsm_internal +/// +template +void trsm_addmod(Side side, Uplo uplo, scalar_t alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector>&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ) +{ + trsm_addmod(internal::TargetType(), + side, uplo, alpha, A, U, VT, S, B, + blockFactorType, ib, priority, layout, queue_index ); +} + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// Host OpenMP task implementation. 
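+/// One OpenMP task is spawned per local tile of B. The lower solve reads the
+/// orthogonal factor U(0, 0); the upper solve reads VT(0, 0) (and, for SVD, the
+/// modified singular values), except for BlockFactor::QR, where only the
+/// triangular factor held in A(0, 0) is needed.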
+/// @ingroup trsm_internal +/// +template +void trsm_addmod(internal::TargetType, + Side side, Uplo uplo, scalar_t alpha, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector>& S, + Matrix& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ) +{ + assert(A.mt() == 1); + assert(U.mt() == 1); + + if (B.numLocalTiles() > 0) { + A.tileGetForReading(0, 0, LayoutConvert(layout)); + if (uplo == Uplo::Lower) { + U.tileGetForReading(0, 0, LayoutConvert(layout)); + } + else { + if (blockFactorType != BlockFactor::QR) { + VT.tileGetForReading(0, 0, LayoutConvert(layout)); + } + } + } + + // TODO figure out if the workspaces can be shared/reused + #pragma omp taskgroup + if (side == Side::Right) { + assert(B.nt() == 1); + for (int64_t i = 0; i < B.mt(); ++i) { + if (B.tileIsLocal(i, 0)) { + #pragma omp task slate_omp_default_none \ + shared( A, U, VT, S, B ) \ + firstprivate(i, layout, side, uplo, ib) priority(priority) + { + B.tileGetForWriting(i, 0, LayoutConvert(layout)); + Tile U_VT; + if (uplo == Uplo::Lower) { + U_VT = U(0, 0); + } + else { + if (blockFactorType == BlockFactor::QR) { + U_VT = A(0, 0); // Not used + } + else { + U_VT = VT(0, 0); + } + } + tile::trsm_addmod(blockFactorType, ib, side, uplo, alpha, + A(0, 0), U_VT, U_VT, S, B(i, 0)); + } + } + } + } + else { + assert(B.mt() == 1); + for (int64_t j = 0; j < B.nt(); ++j) { + if (B.tileIsLocal(0, j)) { + #pragma omp task slate_omp_default_none \ + shared( A, U, VT, S, B ) \ + firstprivate(j, layout, side, uplo, ib) priority(priority) + { + B.tileGetForWriting(0, j, LayoutConvert(layout)); + Tile U_VT; + if (uplo == Uplo::Lower) { + U_VT = U(0, 0); + } + else { + if (blockFactorType == BlockFactor::QR) { + U_VT = A(0, 0); // Not used + } + else { + U_VT = VT(0, 0); + } + } + tile::trsm_addmod(blockFactorType, ib, side, uplo, alpha, + A(0, 0), U_VT, U_VT, S, B(0, j)); + } + } + } + } +} + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// Host nested OpenMP implementation. +/// @ingroup trsm_internal +/// +template +void trsm_addmod(internal::TargetType, + Side side, Uplo uplo, scalar_t alpha, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector>& S, + Matrix& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ) +{ + slate_not_implemented("Target::HostNest isn't yet supported."); +} + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// Host batched implementation. +/// @ingroup trsm_internal +/// +template +void trsm_addmod(internal::TargetType, + Side side, Uplo uplo, scalar_t alpha, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector>& S, + Matrix& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ) +{ + slate_not_implemented("Target::Device isn't yet supported."); +} + +//------------------------------------------------------------------------------ +/// Triangular solve matrix (multiple right-hand sides). +/// GPU device batched cuBLAS implementation. 
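+/// Per device, the local tiles of B are gathered into two batches, one for the
+/// uniformly sized interior tiles and one for the final, possibly smaller, tile,
+/// and each batch is handed to device::batch_trsm_addmod. For the upper solve
+/// the modified singular values are first copied into a device workspace buffer.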
+/// @ingroup trsm_internal +/// +template +void trsm_addmod(internal::TargetType, + Side side, Uplo uplo, scalar_t alpha, + Matrix& A, + Matrix& U, + Matrix& VT, + std::vector>& S, + Matrix& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ) +{ + + using blas::conj; + using ij_tuple = typename BaseMatrix::ij_tuple; + using real_t = blas::real_type; + + assert(B.num_devices() > 0); + assert(A.mt() == 1); + assert(B.uploPhysical() == Uplo::General); + assert(A.mt() == A.nt()); // square + assert(side == Side::Left ? A.mt() == B.mt() : A.mt() == B.nt()); + + assert(B.op() == Op::NoTrans); + assert(A.op() == Op::NoTrans); + + #pragma omp taskgroup + for (int device = 0; device < B.num_devices(); ++device) { + #pragma omp task shared(A, B) priority(priority) \ + firstprivate(device, side, layout, uplo, alpha, queue_index) + { + trace::Block trace_block("internal::trsm_addmod"); + std::set B_tiles_set; + if (side == Side::Right) { + for (int64_t i = 0; i < B.mt(); ++i) { + if (B.tileIsLocal(i, 0) + && device == B.tileDevice(i, 0)) + { + B_tiles_set.insert({i, 0}); + } + } + } + else { + for (int64_t j = 0; j < B.nt(); ++j) { + if (B.tileIsLocal(0, j) + && device == B.tileDevice(0, j)) + { + B_tiles_set.insert({0, j}); + } + } + } + + int64_t batch_size = B_tiles_set.size(); + if (batch_size > 0) { + blas::Queue* queue = B.compute_queue(device, queue_index); + assert(queue != nullptr); + + + A.tileGetForReading(0, 0, device, LayoutConvert(layout)); + if (uplo == Uplo::Lower) { + U.tileGetForReading(0, 0, device, LayoutConvert(layout)); + } + else { + if (blockFactorType != BlockFactor::QR) { + VT.tileGetForReading(0, 0, device, LayoutConvert(layout)); + } + } + B.tileGetForWriting(B_tiles_set, device, LayoutConvert(layout)); + + real_t* dS_ptr; + if (uplo == Uplo::Upper) { + dS_ptr = (real_t*)A.allocWorkspaceBuffer(device, S.size()); + blas::device_memcpy( dS_ptr, S.data(), S.size(), *queue ); + } + + // interior col or row + std::vector a_array0; + std::vector u_array0; + std::vector vt_array0; + std::vector s_array0; + std::vector b_array0; + a_array0.reserve( batch_size ); + b_array0.reserve( batch_size ); + if (uplo == Uplo::Lower) { + u_array0.reserve( batch_size ); + } + else { + if (blockFactorType != BlockFactor::QR) { + vt_array0.reserve( batch_size ); + } + s_array0.reserve( batch_size ); + } + + // bottom-right tile + // todo: replace batch trsm with plain trsm + std::vector a_array1; + std::vector u_array1; + std::vector vt_array1; + std::vector s_array1; + std::vector b_array1; + + int64_t lda0 = 0; + int64_t ldu0 = 0; + int64_t ldvt0 = 0; + int64_t ldb0 = 0; + int64_t lda1 = 0; + int64_t ldu1 = 0; + int64_t ldvt1 = 0; + int64_t ldb1 = 0; + + int64_t mb0 = B.tileMb(0); + int64_t nb0 = B.tileNb(0); + int64_t mb1 = B.tileMb(B.mt()-1); + int64_t nb1 = B.tileNb(B.nt()-1); + + if (side == Side::Right) { + for (int64_t i = 0; i < B.mt()-1; ++i) { + if (B.tileIsLocal(i, 0) + && device == B.tileDevice(i, 0)) + { + a_array0.push_back( A(0, 0, device).data() ); + b_array0.push_back( B(i, 0, device).data() ); + lda0 = A(0, 0, device).stride(); + ldb0 = B(i, 0, device).stride(); + if (uplo == Uplo::Lower) { + u_array0.push_back( U(0, 0, device).data() ); + ldu0 = U(0, 0, device).stride(); + } + else { + s_array0.push_back( dS_ptr ); + if (blockFactorType != BlockFactor::QR) { + vt_array0.push_back( VT(0, 0, device).data() ); + ldvt0 = VT(0, 0, device).stride(); + } + } + } + } + { + int64_t i = B.mt()-1; + if (B.tileIsLocal(i, 0) + && device == 
B.tileDevice(i, 0)) + { + a_array1.push_back( A(0, 0, device).data() ); + b_array1.push_back( B(i, 0, device).data() ); + lda1 = A(0, 0, device).stride(); + ldb1 = B(i, 0, device).stride(); + if (uplo == Uplo::Lower) { + u_array1.push_back( U(0, 0, device).data() ); + ldu1 = U(0, 0, device).stride(); + } + else { + s_array1.push_back( dS_ptr ); + if (blockFactorType != BlockFactor::QR) { + vt_array1.push_back( VT(0, 0, device).data() ); + ldvt1 = VT(0, 0, device).stride(); + } + } + } + } + } + else { + for (int64_t j = 0; j < B.nt()-1; ++j) { + if (B.tileIsLocal(0, j) + && device == B.tileDevice(0, j)) + { + a_array0.push_back( A(0, 0, device).data() ); + b_array0.push_back( B(0, j, device).data() ); + lda0 = A(0, 0, device).stride(); + ldb0 = B(0, j, device).stride(); + if (uplo == Uplo::Lower) { + u_array0.push_back( U(0, 0, device).data() ); + ldu0 = U(0, 0, device).stride(); + } + else { + s_array0.push_back( dS_ptr ); + if (blockFactorType != BlockFactor::QR) { + vt_array0.push_back( VT(0, 0, device).data() ); + ldvt0 = VT(0, 0, device).stride(); + } + } + } + } + { + int64_t j = B.nt()-1; + if (B.tileIsLocal(0, j) + && device == B.tileDevice(0, j)) + { + a_array1.push_back( A(0, 0, device).data() ); + b_array1.push_back( B(0, j, device).data() ); + lda1 = A(0, 0, device).stride(); + ldb1 = B(0, j, device).stride(); + if (uplo == Uplo::Lower) { + u_array1.push_back( U(0, 0, device).data() ); + ldu1 = U(0, 0, device).stride(); + } + else { + s_array1.push_back( dS_ptr ); + if (blockFactorType != BlockFactor::QR) { + vt_array1.push_back( VT(0, 0, device).data() ); + ldvt1 = VT(0, 0, device).stride(); + } + } + } + } + } + + { + std::vector workarray0 (a_array0.size()); + std::vector workarray1 (a_array1.size()); + for (size_t i = 0; i < a_array0.size(); ++i) { + workarray0[i] = A.allocWorkspaceBuffer(device, mb0*nb0); + } + for (size_t i = 0; i < a_array1.size(); ++i) { + workarray1[i] = A.allocWorkspaceBuffer(device, mb1*nb1); + } + + if (a_array0.size() > 0) { + device::batch_trsm_addmod( + blockFactorType, + layout, side, uplo, mb0, nb0, ib, + alpha, a_array0, lda0, + u_array0, ldu0, + vt_array0, ldvt0, + s_array0, + b_array0, ldb0, + workarray0, + a_array0.size(), *queue); + } + + if (a_array1.size() > 0) { + device::batch_trsm_addmod( + blockFactorType, + layout, side, uplo, mb1, nb1, ib, + alpha, a_array1, lda1, + u_array1, ldu1, + vt_array1, ldvt1, + s_array1, + b_array1, ldb1, + workarray1, + a_array1.size(), *queue); + } + + queue->sync(); + + // return workspace memory + for (size_t i = 0; i < a_array0.size(); ++i) { + A.freeWorkspaceBuffer(device, workarray0[i]); + } + for (size_t i = 0; i < a_array1.size(); ++i) { + A.freeWorkspaceBuffer(device, workarray1[i]); + } + } + + // return workspace memory + if (uplo == Uplo::Upper) { + A.freeWorkspaceBuffer( device, (scalar_t*)dS_ptr ); + } + } + } + } + // end omp taskgroup +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. 
+// ---------------------------------------- +template +void trsm_addmod( + Side side, Uplo uplo, float alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod( + Side side, Uplo uplo, float alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod( + Side side, Uplo uplo, float alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod( + Side side, Uplo uplo, float alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +// ---------------------------------------- +template +void trsm_addmod( + Side side, Uplo uplo, double alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod( + Side side, Uplo uplo, double alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod( + Side side, Uplo uplo, double alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod( + Side side, Uplo uplo, double alpha, + Matrix&& A, + Matrix&& U, + Matrix&& VT, + std::vector&& S, + Matrix&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +// ---------------------------------------- +template +void trsm_addmod< Target::HostTask, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod< Target::HostNest, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod< Target::HostBatch, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod< Target::Devices, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +// ---------------------------------------- +template +void trsm_addmod< Target::HostTask, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template 
+void trsm_addmod< Target::HostNest, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod< Target::HostBatch, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +template +void trsm_addmod< Target::Devices, std::complex >( + Side side, Uplo uplo, std::complex alpha, + Matrix>&& A, + Matrix>&& U, + Matrix>&& VT, + std::vector&& S, + Matrix>&& B, + BlockFactor blockFactorType, + int64_t ib, int priority, Layout layout, int64_t queue_index ); + +} // namespace internal +} // namespace slate diff --git a/src/trsm.cc b/src/trsm.cc index 5447d48ea..d54941972 100644 --- a/src/trsm.cc +++ b/src/trsm.cc @@ -75,7 +75,7 @@ void trsm(blas::Side side, opts, Option::MethodTrsm, MethodTrsm::Auto ); if (method == MethodTrsm::Auto) - method = MethodTrsm::select_algo( A, B, opts ); + method = MethodTrsm::select_algo( A, B, side, opts ); switch (method) { case MethodTrsm::TrsmA: diff --git a/src/trsmA.cc b/src/trsmA.cc index c2b9e6f3c..0645b9fbc 100644 --- a/src/trsmA.cc +++ b/src/trsmA.cc @@ -133,7 +133,7 @@ void trsmA( Matrix& B, Options const& opts ) { - Target target = get_option(opts, Option::Target, Target::HostTask); + Target target = Target::HostTask; //get_option(opts, Option::Target, Target::HostTask); switch (target) { case Target::Host: diff --git a/src/trsmA_addmod.cc b/src/trsmA_addmod.cc new file mode 100644 index 000000000..588059728 --- /dev/null +++ b/src/trsmA_addmod.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2017-2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/slate.hh" +#include "work/work.hh" + +namespace slate { + +// specialization namespace differentiates, e.g., +// internal::trsmA_addmod from internal::specialization::trsmA_addmod +namespace internal { +namespace specialization { + +//------------------------------------------------------------------------------ +/// @internal +/// Distributed parallel triangular matrix solve. +/// Generic implementation for any target. 
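+/// For Target::Devices, batch arrays and device workspace are allocated on B up
+/// front; the tile-level work is then delegated to work::trsmA_addmod inside a
+/// single task, after which B's origin tiles are updated and the workspace
+/// released.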
+/// @ingroup trsm_impl +/// +template +void trsmA_addmod(slate::internal::TargetType, + Side side, Uplo uplo, + scalar_t alpha, AddModFactors& W, + Matrix& B, + int64_t lookahead) +{ + if (target == Target::Devices) { + const int64_t batch_size_zero = 0; + const int64_t num_arrays_two = 2; // Number of kernels without lookahead + // Allocate batch arrays = number of kernels without + // lookahead + lookahead + // number of kernels without lookahead = 2 + // (internal::gemm & internal::trsm) + // TODO + // whereas internal::gemm with lookahead will be executed as many as + // lookaheads, thus + // internal::gemm with lookahead needs batch arrays equal to the + // number of lookaheads + // and the batch_arrays_index starts from + // the number of kernels without lookahead, and then incremented by 1 + // for every execution for the internal::gemm with lookahead + B.allocateBatchArrays(batch_size_zero, num_arrays_two); + B.reserveDeviceWorkspace(); + } + + // OpenMP needs pointer types, but vectors are exception safe + std::vector row_vector(W.A.nt()); + uint8_t* row = row_vector.data(); + + // set min number for omp nested active parallel regions + slate::OmpSetMaxActiveLevels set_active_levels( MinOmpActiveLevels ); + + #pragma omp parallel + #pragma omp master + { + #pragma omp task + { + work::trsmA_addmod(side, uplo, alpha, W, B, row, lookahead); + B.tileUpdateAllOrigin(); + } + } + B.releaseWorkspace(); +} + +} // namespace specialization +} // namespace internal + +//------------------------------------------------------------------------------ +/// Version with target as template parameter. +/// @ingroup trsm_impl +/// +template +void trsmA_addmod(blas::Side side, blas::Uplo uplo, + scalar_t alpha, AddModFactors& W, + Matrix& B, + Options const& opts) +{ + int64_t lookahead = get_option(opts, Option::Lookahead, 1); + + internal::specialization::trsmA_addmod(internal::TargetType(), + side, uplo, + alpha, W, + B, + lookahead); +} + +// TODO docs +template +void trsmA_addmod(blas::Side side, blas::Uplo uplo, + scalar_t alpha, AddModFactors& W, + Matrix& B, + Options const& opts) +{ + Target target = Target::HostTask; //get_option(opts, Option::Target, Target::HostTask); + + switch (target) { + case Target::Host: + case Target::HostTask: + trsmA_addmod(side, uplo, alpha, W, B, opts); + break; + case Target::HostNest: + trsmA_addmod(side, uplo, alpha, W, B, opts); + break; + case Target::HostBatch: + trsmA_addmod(side, uplo, alpha, W, B, opts); + break; + case Target::Devices: + trsmA_addmod(side, uplo, alpha, W, B, opts); + break; + } +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. 
+template +void trsmA_addmod( + blas::Side side, blas::Uplo uplo, + float alpha, AddModFactors& W, + Matrix& B, + Options const& opts); + +template +void trsmA_addmod( + blas::Side side, blas::Uplo uplo, + double alpha, AddModFactors& W, + Matrix& B, + Options const& opts); + +template +void trsmA_addmod< std::complex >( + blas::Side side, blas::Uplo uplo, + std::complex alpha, AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Options const& opts); + +template +void trsmA_addmod< std::complex >( + blas::Side side, blas::Uplo uplo, + std::complex alpha, AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Options const& opts); + +} // namespace slate diff --git a/src/trsmB_addmod.cc b/src/trsmB_addmod.cc new file mode 100644 index 000000000..3ca69e9fe --- /dev/null +++ b/src/trsmB_addmod.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2017-2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/slate.hh" +#include "work/work.hh" + +namespace slate { + +namespace impl { + +//------------------------------------------------------------------------------ +/// @internal +/// Distributed parallel triangular matrix solve. +/// Generic implementation for any target. +/// @ingroup trsm_specialization +/// +template +void trsmB_addmod( slate::internal::TargetType, + Side side, blas::Uplo uplo, + scalar_t alpha, AddModFactors& W, + Matrix& B, + Options const& opts) +{ + if (target == Target::Devices) { + int64_t lookahead = get_option( opts, Option::Lookahead, 1 ); + + const int64_t batch_size_zero = 0; + // Allocate batch arrays = number of kernels without + // lookahead + lookahead + // number of kernels without lookahead = 2 + // (internal::gemm & internal::trsm) + // TODO + // whereas internal::gemm with lookahead will be executed as many as + // lookaheads, thus + // internal::gemm with lookahead needs batch arrays equal to the + // number of lookaheads + // and the batch_arrays_index starts from + // the number of kernels without lookahead, and then incremented by 1 + // for every execution for the internal::gemm with lookahead + + // Number of device queues (num_queues): + // 1) trsm ( 1 ) + // 2) gemm for trailing matrix update ( 1 ) + // 3) lookahead number of gemm's ( lookahead ) + const int num_queues = 2 + lookahead; + B.allocateBatchArrays( batch_size_zero, num_queues ); + B.reserveDeviceWorkspace(); + } + + // OpenMP needs pointer types, but vectors are exception safe + std::vector row_vector(W.A.nt()); + uint8_t* row = row_vector.data(); + + // set min number for omp nested active parallel regions + slate::OmpSetMaxActiveLevels set_active_levels( MinOmpActiveLevels ); + + #pragma omp parallel + #pragma omp master + { + #pragma omp task + { + work::trsm_addmod( side, uplo, alpha, W, B, row, opts ); + B.tileUpdateAllOrigin(); + } + } + B.releaseWorkspace(); +} + +} // namespace impl + +//TODO docs +template +void trsmB_addmod( blas::Side side, blas::Uplo uplo, + scalar_t alpha, AddModFactors& W, + Matrix& B, + Options const& opts) +{ + Target target = get_option( opts, Option::Target, Target::HostTask ); + + switch (target) { + case Target::Host: + case Target::HostTask: + impl::trsmB_addmod( + internal::TargetType(), + side, uplo, alpha, W, B, opts ); + break; + case Target::HostNest: + impl::trsmB_addmod( + internal::TargetType(), + side, uplo, alpha, W, 
B, opts ); + break; + case Target::HostBatch: + impl::trsmB_addmod( + internal::TargetType(), + side, uplo, alpha, W, B, opts ); + break; + case Target::Devices: + impl::trsmB_addmod( + internal::TargetType(), + side, uplo, alpha, W, B, opts ); + } +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. +template +void trsmB_addmod( + blas::Side side, blas::Uplo uplo, + float alpha, AddModFactors& W, + Matrix& B, + Options const& opts); + +template +void trsmB_addmod( + blas::Side side, blas::Uplo uplo, + double alpha, AddModFactors& W, + Matrix& B, + Options const& opts); + +template +void trsmB_addmod< std::complex >( + blas::Side side, blas::Uplo uplo, + std::complex alpha, AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Options const& opts); + +template +void trsmB_addmod< std::complex >( + blas::Side side, blas::Uplo uplo, + std::complex alpha, AddModFactors< std::complex >& W, + Matrix< std::complex >& B, + Options const& opts); + +} // namespace slate diff --git a/src/trsm_addmod.cc b/src/trsm_addmod.cc new file mode 100644 index 000000000..22382af6b --- /dev/null +++ b/src/trsm_addmod.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2017-2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/slate.hh" + +namespace slate { + +// TODO docs +template +void trsm_addmod(blas::Side side, blas::Uplo uplo, + scalar_t alpha, AddModFactors& W, + Matrix& B, + Options const& opts) +{ + Method method = get_option( + opts, Option::MethodTrsm, MethodTrsm::Auto ); + + if (method == MethodTrsm::Auto) + method = MethodTrsm::select_algo( W.A, B, side, opts ); + + switch (method) { + case MethodTrsm::TrsmA: + trsmA_addmod( side, uplo, alpha, W, B, opts ); + break; + case MethodTrsm::TrsmB: + trsmB_addmod( side, uplo, alpha, W, B, opts ); + break; + } +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. 
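Aside, not part of the patch: as with trsm in src/trsm.cc above, the Auto selection in trsm_addmod can be bypassed through Option::MethodTrsm. A minimal sketch of forcing the TrsmB variant, assuming the MethodTrsm constants are accepted as option values the same way the dispatch above reads them; the helper name is hypothetical, and W and B stand for previously computed AddModFactors and the right-hand sides:

#include "slate/slate.hh"

// Hypothetical helper, for illustration only.
template <typename scalar_t>
void trsm_addmod_force_trsmB( slate::AddModFactors<scalar_t>& W,
                              slate::Matrix<scalar_t>& B )
{
    slate::Options opts = {
        // MethodTrsm::Auto would instead call MethodTrsm::select_algo( W.A, B, side, opts )
        { slate::Option::MethodTrsm, slate::MethodTrsm::TrsmB },
    };
    slate::trsm_addmod( slate::Side::Left, slate::Uplo::Lower,
                        scalar_t( 1.0 ), W, B, opts );
}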
+template
+void trsm_addmod<float>(
+    blas::Side side, blas::Uplo uplo,
+    float alpha, AddModFactors<float>& A,
+                 Matrix<float>& B,
+    Options const& opts);
+
+template
+void trsm_addmod<double>(
+    blas::Side side, blas::Uplo uplo,
+    double alpha, AddModFactors<double>& A,
+                  Matrix<double>& B,
+    Options const& opts);
+
+template
+void trsm_addmod< std::complex<float> >(
+    blas::Side side, blas::Uplo uplo,
+    std::complex<float> alpha, AddModFactors< std::complex<float> >& A,
+                               Matrix< std::complex<float> >& B,
+    Options const& opts);
+
+template
+void trsm_addmod< std::complex<double> >(
+    blas::Side side, blas::Uplo uplo,
+    std::complex<double> alpha, AddModFactors< std::complex<double> >& A,
+                                Matrix< std::complex<double> >& B,
+    Options const& opts);
+
+} // namespace slate
diff --git a/src/work/work.hh b/src/work/work.hh
index 90a15038e..1a2450505 100644
--- a/src/work/work.hh
+++ b/src/work/work.hh
@@ -42,6 +42,22 @@ void trsmA(Side side,
            scalar_t alpha, TriangularMatrix<scalar_t> A,
                                      Matrix<scalar_t> B,
            uint8_t* row, Options const& opts);
 
+//-----------------------------------------
+// trsm_addmod()
+template <Target target, typename scalar_t>
+void trsm_addmod(Side side, Uplo uplo,
+                 scalar_t alpha, AddModFactors<scalar_t> W,
+                                 Matrix<scalar_t> B,
+                 uint8_t* row, Options const& opts);
+
+//-----------------------------------------
+// trsmA_addmod()
+template <Target target, typename scalar_t>
+void trsmA_addmod(Side side, Uplo uplo,
+                  scalar_t alpha, AddModFactors<scalar_t> W,
+                                  Matrix<scalar_t> B,
+                  uint8_t* row, int64_t lookahead=1);
+
 } // namespace work
 } // namespace slate
diff --git a/src/work/work_trsmA_addmod.cc b/src/work/work_trsmA_addmod.cc
new file mode 100644
index 000000000..f81e26126
--- /dev/null
+++ b/src/work/work_trsmA_addmod.cc
@@ -0,0 +1,889 @@
+// Copyright (c) 2017-2022, University of Tennessee. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+// This program is free software: you can redistribute it and/or modify it under
+// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
+
+#include "slate/slate.hh"
+#include "internal/internal.hh"
+#include "work/work.hh"
+
+namespace slate {
+namespace work {
+
+// TODO update the description of the function
+//------------------------------------------------------------------------------
+/// AddMod factor solve (multiple right-hand sides).
+///
+/// @tparam target
+///     One of HostTask, HostNest, HostBatch, Devices.
+///
+/// @tparam scalar_t
+///     One of float, double, std::complex<float>, std::complex<double>.
+//------------------------------------------------------------------------------
+/// @param[in] side
+///     Whether W appears on the left or on the right of X:
+///     - Side::Left:  solve $W X = \alpha B$
+///     - Side::Right: solve $X W = \alpha B$
+///
+/// @param[in] alpha
+///     The scalar alpha.
+///
+/// @param[in] W
+///     - If side = left,  the m-by-m AddMod factor matrix W;
+///     - if side = right, the n-by-n AddMod factor matrix W.
+///
+/// @param[in,out] B
+///     On entry, the m-by-n matrix B.
+///     On exit, overwritten by the result X.
+///
+/// @param[in] row
+///     A raw pointer to the dummy vector's data. The dummy vector is used
+///     for OpenMP dependency tracking; it is not based on the actual data.
+///     Entries in the dummy vector represent each row of matrix $B$. The
+///     size of row should be the number of block columns of matrix $A$.
+///
+/// @param[in] lookahead
+///     Number of blocks to overlap communication and computation.
+///     lookahead >= 0. Default 1.
+/// +/// @ingroup trsm_internal +/// +template +void trsmA_addmod(Side side, Uplo uplo, + scalar_t alpha, AddModFactors W, + Matrix B, + uint8_t* row, int64_t lookahead) +{ + using blas::conj; + using BcastList = typename Matrix::BcastList; + using std::real; + using std::imag; + + int64_t ib = W.block_size; + auto& A = W.A; + auto& U = W.U_factors; + auto& VT = W.VT_factors; + auto& S = W.singular_values; + auto blockFactorType = W.factorType; + + // Assumes column major + const Layout layout = Layout::ColMajor; + + // Because of the asymmetric between upper and lower, we can't transpose both sides + + // Check sizes + if (side == Side::Left) { + assert(W.A.mt() == B.mt()); + assert(W.A.nt() == B.mt()); + } + else { + assert(W.A.mt() == B.nt()); + assert(W.A.nt() == B.nt()); + } + + int64_t mt = B.mt(); + int64_t nt = B.nt(); + + const int priority_one = 1; + const int priority_zero = 0; + + // Requires 2 queues + if (target == Target::Devices) + assert(B.numComputeQueues() >= 2); + //const int64_t queue_0 = 0; + const int64_t queue_1 = 1; + + const scalar_t one = 1.0; + + if (side == Side::Left) { + if (uplo == Uplo::Lower) { + // ---------------------------------------- + // Lower, Left case + // Forward sweep + for (int64_t k = 0; k < mt; ++k) { + + // panel (Akk tile) + #pragma omp task depend(inout:row[k]) priority(1) + { + // Scale the RHS in order to be consistent with the upper case + if (k == 0 && alpha != one) { + for (int64_t i = 0; i < mt; ++i) { + for (int64_t j = 0; j < nt; ++j) { + if (B.tileIsLocal(i, j)) { + tile::scale( alpha, B(i, j) ); + } + } + } + } + + // Create the local B tiles where A(k,k) is located + if (A.tileIsLocal(k, k)) { + for (int64_t j = 0; j < nt; ++j) { + if (! B.tileIsLocal(k, j) && ! B.tileExists(k, j)) { + B.tileInsert(k, j); + B.at(k, j).set(0, 0); + } + } + } + + // Gather B(k,:) to rank owning diagonal block A(k,k) + using ReduceList = typename Matrix::ReduceList; + ReduceList reduce_list_B; + for (int64_t j = 0; j < nt; ++j) { + reduce_list_B.push_back({k, j, + A.sub(k, k, k, k), + { A.sub(k, k, 0, k), + B.sub(k, k, j, j ) + } + }); + } + B.template listReduce(reduce_list_B, layout); + + if (A.tileIsLocal(k, k)) { + // solve A(k, k) B(k, :) = alpha B(k, :) + internal::trsmA_addmod( + Side::Left, Uplo::Lower, + one, A.sub(k, k, k, k), + U.sub(k, k, k, k), + VT.sub(k, k, k, k), + std::move(S[k]), + B.sub(k, k, 0, nt-1), + blockFactorType, + ib, priority_one, layout, queue_1); + } + + // Send the solution back to where it belongs + // TODO : could be part of the bcast of the solution, + // but not working now + if (A.tileIsLocal(k, k)) { + for (int64_t j = 0; j < nt; ++j) { + int dest = B.tileRank(k, j); + B.tileSend(k, j, dest); + } + } + else { + const int root = A.tileRank(k, k); + + for (int64_t j = 0; j < nt; ++j) { + if (B.tileIsLocal(k, j)) { + B.tileRecv(k, j, root, layout); + } + } + } + + for (int64_t j = 0; j < nt; ++j) + if (B.tileExists(k, j) && ! 
B.tileIsLocal(k, j)) + B.tileErase(k, j); + + // Bcast the result of the solve, B(k,:) to + // ranks owning block row A(k + 1 : mt, k) + BcastList bcast_list_upd_B; + for (int64_t j = 0; j < nt; ++j) { + bcast_list_upd_B.push_back( + {k, j, { A.sub(k + 1, mt - 1, k, k), }}); + } + B.template listBcast(bcast_list_upd_B, layout); + } + + // lookahead update, B(k+1:k+la, :) -= A(k+1:k+la, k) B(k, :) + for (int64_t i = k+1; i < k+1+lookahead && i < mt; ++i) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[i]) priority(1) + { + if (A.tileIsLocal(i, k)) { + for (int64_t j = 0; j < nt; ++j) { + if (! B.tileIsLocal(i, j) + && ! B.tileExists(i, j)) + { + B.tileInsert(i, j); + B.at(i, j).set(0, 0); + } + } + } + // TODO: execute lookahead on devices + internal::gemmA( + -one, A.sub(i, i, k, k), + B.sub(k, k, 0, nt-1), + one, B.sub(i, i, 0, nt-1), + layout, priority_one); + } + } + + // trailing update, + // B(k+1+la:mt-1, :) -= A(k+1+la:mt-1, k) B(k, :) + // Updates rows k+1+la to mt-1, but two depends are sufficient: + // depend on k+1+la is all that is needed in next iteration; + // depend on mt-1 daisy chains all the trailing updates. + if (k+1+lookahead < mt) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[k+1+lookahead]) \ + depend(inout:row[mt-1]) + { + for (int64_t i = k+1+lookahead; i < mt; ++i) { + if (A.tileIsLocal(i, k)) { + for (int64_t j = 0; j < nt; ++j) { + if (! B.tileIsLocal(i, j) + && ! B.tileExists(i, j)) + { + B.tileInsert(i, j); + B.at(i, j).set(0, 0); + } + } + } + } + + //internal::gemmA( + internal::gemmA( + -one, A.sub(k+1+lookahead, mt-1, k, k), + B.sub(k, k, 0, nt-1), + one, B.sub(k+1+lookahead, mt-1, 0, nt-1), + layout, priority_zero); //, queue_0); + } + } + + // Erase remote or workspace tiles. + #pragma omp task depend(inout:row[k]) + { + auto A_col_k = A.sub( k, mt-1, k, k ); + A_col_k.releaseRemoteWorkspace(); + A_col_k.releaseLocalWorkspace(); + + auto B_row_k = B.sub( k, k, 0, nt-1 ); + B_row_k.releaseRemoteWorkspace(); + // Copy back modifications to tiles in the B panel + // before they are erased. + B_row_k.tileUpdateAllOrigin(); + B_row_k.releaseLocalWorkspace(); + } + } + } + else { + // ---------------------------------------- + // Upper, Left case + // Backward sweep + for (int64_t k = mt-1; k >= 0; --k) { + + // panel (Akk tile) + #pragma omp task depend(inout:row[k]) priority(1) + { + // Scale the RHS to handle the alpha issue since B is moved + // around instead of the A as in trsm + if (k == mt - 1 && alpha != one) { + for (int64_t i = 0; i < mt; ++i) { + for (int64_t j = 0; j < nt; ++j) { + if (B.tileIsLocal(i, j)) { + tile::scale( alpha, B(i, j) ); + } + } + } + } + + // Create the local B tiles where A(k,k) is located + if (A.tileIsLocal(k, k)) { + for (int64_t j = 0; j < nt; ++j) { + if (! B.tileIsLocal(k, j) && ! 
B.tileExists(k, j)) { + B.tileInsert(k, j); + B.at(k, j).set(0, 0); // Might not needed if alph is set correctly + } + } + } + + // Gather B(k,:) to rank owning diagonal block A(k,k) + using ReduceList = typename Matrix::ReduceList; + ReduceList reduce_list_B; + for (int64_t j = 0; j < nt; ++j) { + reduce_list_B.push_back({k, j, + A.sub(k, k, k, k), + { A.sub(k, k, k, mt - 1), + B.sub(k, k, j, j ) + } + }); + } + B.template listReduce(reduce_list_B, layout); + + if (A.tileIsLocal(k, k)) { + // solve A(k, k) B(k, :) = alpha B(k, :) + internal::trsmA_addmod( + Side::Left, Uplo::Upper, + one, A.sub(k, k, k, k), + U.sub(k, k, k, k), + VT.sub(k, k, k, k), + std::move(S[k]), + B.sub(k, k, 0, nt-1), + blockFactorType, + ib, priority_one, layout, queue_1); + } + + // Send the solution back to where it belongs + // TODO : could be part of the bcast of the solution, + // but not working now + if (A.tileIsLocal(k, k)) { + for (int64_t j = 0; j < nt; ++j) { + int dest = B.tileRank(k, j); + B.tileSend(k, j, dest); + } + } + else { + const int root = A.tileRank(k, k); + + for (int64_t j = 0; j < nt; ++j) { + if (B.tileIsLocal(k, j)) { + B.tileRecv(k, j, root, layout); + } + } + } + + for (int64_t j = 0; j < nt; ++j) + if (B.tileExists(k, j) && ! B.tileIsLocal(k, j)) + B.tileErase(k, j); + + // Bcast the result of the solve, B(k,:) to + // ranks owning block row A(k + 1 : mt, k) + BcastList bcast_list_upd_B; + for (int64_t j = 0; j < nt; ++j) { + bcast_list_upd_B.push_back( + {k, j, { A.sub(0, k - 1, k, k), }}); + } + B.template listBcast(bcast_list_upd_B, layout); + } + + // lookahead update, B(k-la:k-1, :) -= A(k-la:k-1, k) B(k, :) + for (int64_t i = k-1; i > k-1-lookahead && i >= 0; --i) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[i]) priority(1) + { + if (A.tileIsLocal(i, k)) { + for (int64_t j = 0; j < nt; ++j) { + if (! B.tileIsLocal(i, j) + && ! B.tileExists(i, j)) + { + B.tileInsert(i, j); + B.at(i, j).set(0, 0); + } + } + } + // TODO: execute lookahead on devices + internal::gemmA( + -one, A.sub(i, i, k, k), + B.sub(k, k, 0, nt-1), + one, B.sub(i, i, 0, nt-1), + layout, priority_one); + } + } + + // trailing update, + // B(0:k-1-la, :) -= A(0:k-1-la, k) B(k, :) + // Updates rows 0 to k-1-la, but two depends are sufficient: + // depend on k-1-la is all that is needed in next iteration; + // depend on 0 daisy chains all the trailing updates. + if (k-1-lookahead >= 0) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[k-1-lookahead]) \ + depend(inout:row[0]) + { + for (int64_t i = 0; i < k - lookahead; ++i) { + if (A.tileIsLocal(i, k)) { + for (int64_t j = 0; j < nt; ++j) { + if (! B.tileIsLocal(i, j) + && ! B.tileExists(i, j)) + { + B.tileInsert(i, j); + B.at(i, j).set(0, 0); + } + } + } + } + + //internal::gemm( + internal::gemmA( + -one, A.sub(0, k-1-lookahead, k, k), + B.sub(k, k, 0, nt-1), + one, B.sub(0, k-1-lookahead, 0, nt-1), + layout, priority_zero); //, queue_0); + } + } + + // Erase remote or workspace tiles. + #pragma omp task depend(inout:row[k]) + { + auto A_col_k = A.sub( 0, k, k, k ); + A_col_k.releaseRemoteWorkspace(); + A_col_k.releaseLocalWorkspace(); + + auto B_row_k = B.sub( k, k, 0, nt-1 ); + B_row_k.releaseRemoteWorkspace(); + // Copy back modifications to tiles in the B panel + // before they are erased. 
+ B_row_k.tileUpdateAllOrigin(); + B_row_k.releaseLocalWorkspace(); + } + } + } + } + else { + if (uplo == Uplo::Lower) { + // ---------------------------------------- + // Lower, Right case + // Backward sweep + for (int64_t k = nt-1; k >= 0; --k) { + + // panel (Akk tile) + #pragma omp task depend(inout:row[k]) priority(1) + { + // Scale the RHS to handle the alpha issue since B is moved + // around instead of the A as in trsm + if (k == nt - 1 && alpha != one) { + for (int64_t i = 0; i < mt; ++i) { + for (int64_t j = 0; j < nt; ++j) { + if (B.tileIsLocal(i, j)) { + tile::scale( alpha, B(i, j) ); + } + } + } + } + + // Create the local B tiles where A(k,k) is located + if (A.tileIsLocal(k, k)) { + for (int64_t i = 0; i < mt; ++i) { + if (! B.tileIsLocal(i, k) && ! B.tileExists(i, k)) { + B.tileInsert(i, k); + B.at(i, k).set(0, 0); // Might not needed if alph is set correctly + } + } + } + + // Gather B(:,k) to rank owning diagonal block A(k,k) + using ReduceList = typename Matrix::ReduceList; + ReduceList reduce_list_B; + for (int64_t i = 0; i < mt; ++i) { + reduce_list_B.push_back({i, k, + A.sub(k, k, k, k), + { A.sub(k, nt - 1, k, k), + B.sub(i, i, k, k) + } + }); + } + B.template listReduce(reduce_list_B, layout); + + if (A.tileIsLocal(k, k)) { + // solve A(k, k) B(k, :) = alpha B(k, :) + internal::trsmA_addmod( + Side::Right, Uplo::Lower, + one, A.sub(k, k, k, k), + U.sub(k, k, k, k), + VT.sub(k, k, k, k), + std::move(S[k]), + B.sub(0, mt-1, k, k), + blockFactorType, + ib, priority_one, layout, queue_1); + } + + // Send the solution back to where it belongs + // TODO : could be part of the bcast of the solution, + // but not working now + if (A.tileIsLocal(k, k)) { + for (int64_t i = 0; i < mt; ++i) { + int dest = B.tileRank(i, k); + B.tileSend(i, k, dest); + } + } + else { + const int root = A.tileRank(k, k); + + for (int64_t i = 0; i < mt; ++i) { + if (B.tileIsLocal(i, k)) { + B.tileRecv(i, k, root, layout); + } + } + } + + for (int64_t i = 0; i < mt; ++i) + if (B.tileExists(i, k) && ! B.tileIsLocal(i, k)) + B.tileErase(i, k); + + // Bcast the result of the solve, B(:,k) to + // ranks owning block column A(k, k + 1 : nt) + BcastList bcast_list_upd_B; + for (int64_t i = 0; i < mt; ++i) { + bcast_list_upd_B.push_back( + {i, k, { A.sub(k, k, 0, k - 1), }}); + } + B.template listBcast(bcast_list_upd_B, layout); + } + + // lookahead update, B(:, k-la:k-1) -= B(:, k) A(k, k-la:k-1) + for (int64_t j = k-1; j > k-1-lookahead && j >= 0; --j) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[j]) priority(1) + { + if (A.tileIsLocal(k, j)) { + for (int64_t i = 0; i < mt; ++i) { + if (! B.tileIsLocal(i, j) + && ! B.tileExists(i, j)) + { + B.tileInsert(i, j); + B.at(i, j).set(0, 0); + } + } + } + // TODO: execute lookahead on devices + //internal::gemmB( + // -one, B.sub(0, mt-1, k, k), + // A.sub(k, k, j, j), + // one, B.sub(0, mt-1, j, j), + // layout, priority_one); + internal::gemmA( + -one, conj_transpose(A.sub(k, k, j, j)), + conj_transpose(B.sub(0, mt-1, k, k)), + one, conj_transpose(B.sub(0, mt-1, j, j)), + layout, priority_one); + } + } + + // trailing update, + // B(:, 0:k-1-la) -= B(:, k) A(k, 0:k-1-la) + // Updates columns 0 to k-1-la, but two depends are sufficient: + // depend on k-1-la is all that is needed in next iteration; + // depend on 0 daisy chains all the trailing updates. 
+ if (k-1-lookahead >= 0) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[k-1-lookahead]) \ + depend(inout:row[0]) + { + for (int64_t j = 0; j < k - lookahead; ++j) { + if (A.tileIsLocal(k, j)) { + for (int64_t i = 0; i < mt; ++i) { + if (! B.tileIsLocal(i, j) + && ! B.tileExists(i, j)) + { + B.tileInsert(i, j); + B.at(i, j).set(0, 0); + } + } + } + } + + //internal::gemmB( + // -one, B.sub(0, mt-1, k, k), + // A.sub(k, k, 0, k-1-lookahead), + // one, B.sub(0, mt-1, 0, k-1-lookahead), + // layout, priority_zero); //, queue_0); + internal::gemmA( + -one, conj_transpose(A.sub(k, k, 0, k-1-lookahead)), + conj_transpose(B.sub(0, mt-1, k, k)), + one, conj_transpose(B.sub(0, mt-1, 0, k-1-lookahead)), + layout, priority_zero); //, queue_0); + } + } + + // Erase remote or workspace tiles. + #pragma omp task depend(inout:row[k]) + { + auto A_row_k = A.sub( k, k, 0, k ); + A_row_k.releaseRemoteWorkspace(); + A_row_k.releaseLocalWorkspace(); + + auto B_col_k = B.sub( 0, mt-1, k, k ); + B_col_k.releaseRemoteWorkspace(); + // Copy back modifications to tiles in the B panel + // before they are erased. + B_col_k.tileUpdateAllOrigin(); + B_col_k.releaseLocalWorkspace(); + } + } + } + else { + // ---------------------------------------- + // Upper, Right case + // Forward sweep + for (int64_t k = 0; k < nt; ++k) { + + // panel (Akk tile) + #pragma omp task depend(inout:row[k]) priority(1) + { + // Scale the RHS in order to be consistent with the upper case + if (k == 0 && alpha != one) { + for (int64_t i = 0; i < mt; ++i) { + for (int64_t j = 0; j < nt; ++j) { + if (B.tileIsLocal(i, j)) { + tile::scale( alpha, B(i, j) ); + } + } + } + } + + // Create the local B tiles where A(k,k) is located + if (A.tileIsLocal(k, k)) { + for (int64_t i = 0; i < mt; ++i) { + if (! B.tileIsLocal(i, k) && ! B.tileExists(i, k)) { + B.tileInsert(i, k); + B.at(i, k).set(0, 0); + } + } + } + + // Gather B(:,k) to rank owning diagonal block A(k,k) + using ReduceList = typename Matrix::ReduceList; + ReduceList reduce_list_B; + for (int64_t i = 0; i < mt; ++i) { + reduce_list_B.push_back({i, k, + A.sub(k, k, k, k), + { A.sub(0, k, k, k), + B.sub(i, i, k, k ) + } + }); + } + B.template listReduce(reduce_list_B, layout); + + if (A.tileIsLocal(k, k)) { + // solve B(:, k) A(k, k) = alpha B(:, k) + internal::trsmA_addmod( + Side::Right, Uplo::Upper, + one, A.sub(k, k, k, k), + U.sub(k, k, k, k), + VT.sub(k, k, k, k), + std::move(S[k]), + B.sub(0, mt-1, k, k), + blockFactorType, + ib, priority_one, layout, queue_1); + } + + // Send the solution back to where it belongs + // TODO : could be part of the bcast of the solution, + // but not working now + if (A.tileIsLocal(k, k)) { + for (int64_t i = 0; i < mt; ++i) { + int dest = B.tileRank(i, k); + B.tileSend(i, k, dest); + } + } + else { + const int root = A.tileRank(k, k); + + for (int64_t i = 0; i < mt; ++i) { + if (B.tileIsLocal(i, k)) { + B.tileRecv(i, k, root, layout); + } + } + } + + for (int64_t i = 0; i < mt; ++i) + if (B.tileExists(i, k) && ! 
B.tileIsLocal(i, k)) + B.tileErase(i, k); + + // Bcast the result of the solve, B(:,k) to + // ranks owning block column A(k, k + 1 : nt) + BcastList bcast_list_upd_B; + for (int64_t i = 0; i < mt; ++i) { + bcast_list_upd_B.push_back( + {i, k, { A.sub(k, k, k + 1, nt - 1), }}); + } + B.template listBcast(bcast_list_upd_B, layout); + } + + // lookahead update, B(:, k+1:k+la) -= B(:, k) A(k, k+1:k+la) + for (int64_t j = k+1; j < k+1+lookahead && j < nt; ++j) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[j]) priority(1) + { + if (A.tileIsLocal(k, j)) { + for (int64_t i = 0; i < mt; ++i) { + if (! B.tileIsLocal(i, j) + && ! B.tileExists(i, j)) + { + B.tileInsert(i, j); + B.at(i, j).set(0, 0); + } + } + } + // TODO: execute lookahead on devices + //internal::gemmB( + // -one, B.sub(0, mt-1, k, k), + // A.sub(k, k, j, j), + // one, B.sub(0, mt-1, j, j), + // layout, priority_one); + internal::gemmA( + -one, conj_transpose(A.sub(k, k, j, j)), + conj_transpose(B.sub(0, mt-1, k, k)), + one, conj_transpose(B.sub(0, mt-1, j, j)), + layout, priority_one); + } + } + + // trailing update, + // B(:, k+1+la:nt-1) -= B(:, k) A(k, k+1+la:nt-1) + // Updates columns k+1+la to nt-1, but two depends are sufficient: + // depend on k+1+la is all that is needed in next iteration; + // depend on nt-1 daisy chains all the trailing updates. + if (k+1+lookahead < nt) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[k+1+lookahead]) \ + depend(inout:row[nt-1]) + { + for (int64_t j = k+1+lookahead; j < nt; ++j) { + if (A.tileIsLocal(k, j)) { + for (int64_t i = 0; i < mt; ++i) { + if (! B.tileIsLocal(i, j) + && ! B.tileExists(i, j)) + { + B.tileInsert(i, j); + B.at(i, j).set(0, 0); + } + } + } + } + + //internal::gemmB( + // -one, B.sub(0, mt-1, k, k), + // A.sub(k, k, k+1+lookahead, nt-1), + // one, B.sub(0, mt-1, k+1+lookahead, nt-1), + // layout, priority_zero); //, queue_0); + internal::gemmA( + -one, conj_transpose(A.sub(k, k, k+1+lookahead, nt-1)), + conj_transpose(B.sub(0, mt-1, k, k)), + one, conj_transpose(B.sub(0, mt-1, k+1+lookahead, nt-1)), + layout, priority_zero); //, queue_0); + } + } + + // Erase remote or workspace tiles. + #pragma omp task depend(inout:row[k]) + { + auto A_row_k = A.sub( k, k, k, nt-1 ); + A_row_k.releaseRemoteWorkspace(); + A_row_k.releaseLocalWorkspace(); + + auto B_col_k = B.sub( 0, mt-1, k, k ); + B_col_k.releaseRemoteWorkspace(); + // Copy back modifications to tiles in the B panel + // before they are erased. + B_col_k.tileUpdateAllOrigin(); + B_col_k.releaseLocalWorkspace(); + } + } + } + } + + #pragma omp taskwait +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. 
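Aside, not part of the patch: the dummy row vector documented above exists only to give the OpenMP runtime addresses to hang task dependences on; the tasks never touch its contents. A standalone sketch of that panel / lookahead / daisy-chained-trailing pattern, with printf standing in for the real tile work and made-up block counts:

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    const int mt = 8, lookahead = 1;
    std::vector<std::uint8_t> row_vector( mt );
    std::uint8_t* row = row_vector.data();

    #pragma omp parallel
    #pragma omp master
    for (int k = 0; k < mt; ++k) {
        // panel: solve block row k
        #pragma omp task depend( inout: row[k] ) priority(1)
        std::printf( "panel  k=%d\n", k );

        // lookahead updates read the panel of step k
        for (int i = k+1; i < k+1+lookahead && i < mt; ++i) {
            #pragma omp task depend( in: row[k] ) depend( inout: row[i] ) priority(1)
            std::printf( "update i=%d using k=%d\n", i, k );
        }

        // trailing update: depending on row[k+1+lookahead] and row[mt-1] is enough
        // to order it against the next panel and against later trailing updates
        if (k+1+lookahead < mt) {
            #pragma omp task depend( in: row[k] ) \
                depend( inout: row[k+1+lookahead] ) depend( inout: row[mt-1] )
            std::printf( "trail  k=%d\n", k );
        }
    }   // tasks complete at the implicit barrier ending the parallel region
    return 0;
}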
+// ---------------------------------------- +template +void trsmA_addmod( + Side side, Uplo uplo, + float alpha, AddModFactors A, + Matrix B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod( + Side side, Uplo uplo, + float alpha, AddModFactors W, + Matrix B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod( + Side side, Uplo uplo, + float alpha, AddModFactors W, + Matrix B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod( + Side side, Uplo uplo, + float alpha, AddModFactors W, + Matrix B, + uint8_t* row, int64_t lookahead); + +// ---------------------------------------- +template +void trsmA_addmod( + Side side, Uplo uplo, + double alpha, AddModFactors W, + Matrix B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod( + Side side, Uplo uplo, + double alpha, AddModFactors W, + Matrix B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod( + Side side, Uplo uplo, + double alpha, AddModFactors W, + Matrix B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod( + Side side, Uplo uplo, + double alpha, AddModFactors W, + Matrix B, + uint8_t* row, int64_t lookahead); + +// ---------------------------------------- +template +void trsmA_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> W, + Matrix> B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> W, + Matrix> B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> W, + Matrix> B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> W, + Matrix> B, + uint8_t* row, int64_t lookahead); + +// ---------------------------------------- +template +void trsmA_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> W, + Matrix> B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> W, + Matrix> B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> W, + Matrix> B, + uint8_t* row, int64_t lookahead); + +template +void trsmA_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> W, + Matrix> B, + uint8_t* row, int64_t lookahead); + +} // namespace work +} // namespace slate diff --git a/src/work/work_trsm_addmod.cc b/src/work/work_trsm_addmod.cc new file mode 100644 index 000000000..967fb68e9 --- /dev/null +++ b/src/work/work_trsm_addmod.cc @@ -0,0 +1,569 @@ +// Copyright (c) 2017-2022, University of Tennessee. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// This program is free software: you can redistribute it and/or modify it under +// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. + +#include "slate/slate.hh" +#include "internal/internal.hh" +#include "work/work.hh" + +namespace slate { +namespace work { + +//------------------------------------------------------------------------------ +/// AddMod factors solve matrix (multiple right-hand sides). +/// +/// @tparam target +/// One of HostTask, HostNest, HostBatch, Devices. +/// +/// @tparam scalar_t +/// One of float, double, std::complex, std::complex. 
+//------------------------------------------------------------------------------ +/// @param[in] side +/// Whether W appears on the left or on the right of X: +/// - Side::Left: solve $W X = \alpha B$ +/// - Side::Right: solve $X W = \alpha B$ +/// +/// @param[in] alpha +/// The scalar alpha. +/// +/// @param[in] W +/// - If side = left, the m-by-m AddMod factor matrix W; +/// - if side = right, the n-by-n AddMod factor matrix W. +/// +/// @param[in,out] B +/// On entry, the m-by-n matrix B. +/// On exit, overwritten by the result X. +/// +/// @param[in] row +/// A raw pointer to a dummy vector data. The dummy vector is used for +/// OpenMP dependencies tracking, not based on the actual data. Entries +/// in the dummy vector represent each row of matrix $B$. The size of +/// row should be number of block columns of matrix $A$. +/// +/// @param[in] lookahead +/// Number of blocks to overlap communication and computation. +/// lookahead >= 0. Default 1. +/// +/// @ingroup trsm_addmod_internal +/// +template +void trsm_addmod(Side side, Uplo uplo, + scalar_t alpha, AddModFactors W, + Matrix B, + uint8_t* row, Options const& opts) +{ + using blas::conj; + using BcastList = typename Matrix::BcastList; + + const scalar_t one = 1.0; + + int64_t lookahead = get_option( opts, Option::Lookahead, 1 ); + + int64_t ib = W.block_size; + auto& A = W.A; + auto& U = W.U_factors; + auto& VT = W.VT_factors; + auto& S = W.singular_values; + auto blockFactorType = W.factorType; + + // Assumes column major + const Layout layout = Layout::ColMajor; + + // Because of the asymmetric between upper and lower, we can't transpose both sides + + // Check sizes + if (side == Side::Left) { + assert(A.mt() == B.mt()); + assert(A.nt() == B.mt()); + } + else { + assert(A.mt() == B.nt()); + assert(A.nt() == B.nt()); + } + + int64_t mt = B.mt(); + int64_t nt = B.nt(); + + const int priority_one = 1; + const int priority_zero = 0; + + const int64_t queue_0 = 0; + const int64_t queue_1 = 1; + + if (side == Side::Left) { + if (uplo == Uplo::Lower) { + // ---------------------------------------- + // Lower, Left case + // Forward sweep + for (int64_t k = 0; k < mt; ++k) { + scalar_t alph = k == 0 ? 
alpha : one; + + // panel (Akk tile) + #pragma omp task depend(inout:row[k]) priority(1) + { + // send A(k, k) to ranks owning block row B(k, :) + A .template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout); + U .template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout); + + // solve A(k, k) B(k, :) = alpha B(k, :) + internal::trsm_addmod( + Side::Left, Uplo::Lower, + alph, A.sub(k, k, k, k), + U.sub(k, k, k, k), + VT.sub(k, k, k, k), + std::move(S[k]), + B.sub(k, k, 0, nt-1), + blockFactorType, ib, priority_one, layout, queue_1 ); + + // send A(i=k+1:mt-1, k) to ranks owning block row B(i, :) + BcastList bcast_list_A; + for (int64_t i = k+1; i < mt; ++i) + bcast_list_A.push_back({i, k, {B.sub(i, i, 0, nt-1)}}); + A.template listBcast(bcast_list_A, layout); + + // send B(k, j=0:nt-1) to ranks owning + // block col B(k+1:mt-1, j) + BcastList bcast_list_B; + for (int64_t j = 0; j < nt; ++j) { + bcast_list_B.push_back( + {k, j, {B.sub(k+1, mt-1, j, j)}}); + } + B.template listBcast(bcast_list_B, layout); + } + + // lookahead update, B(k+1:k+la, :) -= A(k+1:k+la, k) B(k, :) + for (int64_t i = k+1; i < k+1+lookahead && i < mt; ++i) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[i]) priority(1) + { + internal::gemm( + -one, A.sub(i, i, k, k), + B.sub(k, k, 0, nt-1), + alph, B.sub(i, i, 0, nt-1), + layout, priority_one, i-k+1 ); + } + } + + // trailing update, + // B(k+1+la:mt-1, :) -= A(k+1+la:mt-1, k) B(k, :) + // Updates rows k+1+la to mt-1, but two depends are sufficient: + // depend on k+1+la is all that is needed in next iteration; + // depend on mt-1 daisy chains all the trailing updates. + if (k+1+lookahead < mt) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[k+1+lookahead]) \ + depend(inout:row[mt-1]) + { + internal::gemm( + -one, A.sub(k+1+lookahead, mt-1, k, k), + B.sub(k, k, 0, nt-1), + alph, B.sub(k+1+lookahead, mt-1, 0, nt-1), + layout, priority_zero, queue_0 ); + } + } + + // Erase remote or workspace tiles. + #pragma omp task depend(inout:row[k]) + { + auto A_panel = A.sub(k, mt-1, k, k); + A_panel.releaseRemoteWorkspace(); + A_panel.releaseLocalWorkspace(); + + auto B_panel = B.sub(k, k, 0, nt-1); + B_panel.releaseRemoteWorkspace(); + + // Copy back modifications to tiles in the B panel + // before they are erased. + B_panel.tileUpdateAllOrigin(); + B_panel.releaseLocalWorkspace(); + } + } + } + else { + // ---------------------------------------- + // Upper, Left case + // Backward sweep + for (int64_t k = mt-1; k >= 0; --k) { + scalar_t alph = k == (mt-1) ? 
alpha : one; + + // panel (Akk tile) + #pragma omp task depend(inout:row[k]) priority(1) + { + // send A(k, k) to ranks owning block row B(k, :) + A .template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout); + if (blockFactorType != BlockFactor::QR) { + VT.template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout); + } + + // solve A(k, k) B(k, :) = alpha B(k, :) + internal::trsm_addmod( + Side::Left, Uplo::Upper, + alph, A.sub(k, k, k, k), + U.sub(k, k, k, k), + VT.sub(k, k, k, k), + std::move(S[k]), + B.sub(k, k, 0, nt-1), + blockFactorType, ib, priority_one, layout, queue_1 ); + + // send A(i=0:k-1, k) to ranks owning block row B(i, :) + BcastList bcast_list_A; + for (int64_t i = 0; i < k; ++i) + bcast_list_A.push_back({i, k, {B.sub(i, i, 0, nt-1)}}); + A.template listBcast(bcast_list_A, layout); + + // send B(k, j=0:nt-1) to ranks owning block col B(0:k-1, j) + BcastList bcast_list_B; + for (int64_t j = 0; j < nt; ++j) + bcast_list_B.push_back({k, j, {B.sub(0, k-1, j, j)}}); + B.template listBcast(bcast_list_B, layout); + } + + // lookahead update, B(k-la:k-1, :) -= A(k-la:k-1, k) B(k, :) + for (int64_t i = k-1; i > k-1-lookahead && i >= 0; --i) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[i]) priority(1) + { + internal::gemm( + -one, A.sub(i, i, k, k), + B.sub(k, k, 0, nt-1), + alph, B.sub(i, i, 0, nt-1), + layout, priority_one, i-k+lookahead+2 ); + } + } + + // trailing update, B(0:k-1-la, :) -= A(0:k-1-la, k) B(k, :) + // Updates rows 0 to k-1-la, but two depends are sufficient: + // depend on k-1-la is all that is needed in next iteration; + // depend on 0 daisy chains all the trailing updates. + if (k-1-lookahead >= 0) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[k-1-lookahead]) \ + depend(inout:row[0]) + { + internal::gemm( + -one, A.sub(0, k-1-lookahead, k, k), + B.sub(k, k, 0, nt-1), + alph, B.sub(0, k-1-lookahead, 0, nt-1), + layout, priority_zero, queue_0 ); + } + } + + // Erase remote or workspace tiles. + #pragma omp task depend(inout:row[k]) + { + auto A_panel = A.sub(0, k, k, k); + A_panel.releaseRemoteWorkspace(); + A_panel.releaseLocalWorkspace(); + + auto B_panel = B.sub(k, k, 0, nt-1); + B_panel.releaseRemoteWorkspace(); + + // Copy back modifications to tiles in the B panel + // before they are erased. + B_panel.tileUpdateAllOrigin(); + B_panel.releaseLocalWorkspace(); + } + } + } + } + else { + if (uplo == Uplo::Lower) { + // ---------------------------------------- + // Lower, Right case + // Backward sweep + for (int64_t k = nt-1; k >= 0; --k) { + scalar_t alph = k == (nt-1) ? 
alpha : one; + + // panel (Akk tile) + #pragma omp task depend(inout:row[k]) priority(1) + { + // send A(k, k) to ranks owning block column B(:, k) + A .template tileBcast(k, k, B.sub(0, mt-1, k, k), layout); + U .template tileBcast(k, k, B.sub(0, mt-1, k, k), layout); + + // solve B(k, :) A(k, k) = alpha B(k, :) + internal::trsm_addmod( + Side::Right, Uplo::Lower, + alph, A.sub(k, k, k, k), + U.sub(k, k, k, k), + VT.sub(k, k, k, k), + std::move(S[k]), + B.sub(0, mt-1, k, k), + blockFactorType, ib, priority_one, layout, queue_1 ); + + // send A(k, j=0:k-1) to ranks owning block column B(:, j) + BcastList bcast_list_A; + for (int64_t j = 0; j < k; ++j) + bcast_list_A.push_back({k, j, {B.sub(0, mt-1, j, j)}}); + A.template listBcast(bcast_list_A, layout); + + // send B(i=0:nt-1, k) to ranks owning block col B(i, 0:k-1) + BcastList bcast_list_B; + for (int64_t i = 0; i < mt; ++i) + bcast_list_B.push_back({i, k, {B.sub(i, i, 0, k-1)}}); + B.template listBcast(bcast_list_B, layout); + } + + // lookahead update, B(:, k-la:k-1) -= B(:, k) A(k, k-la:k-1) + for (int64_t j = k-1; j > k-1-lookahead && j >= 0; --j) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[j]) priority(1) + { + internal::gemm( + -one, B.sub(0, mt-1, k, k), + A.sub(k, k, j, j), + alph, B.sub(0, mt-1, j, j), + layout, priority_one, j-k+lookahead+2 ); + } + } + + // trailing update, B(:, 0:k-1-la) -= B(:, k) A(k, 0:k-1-la) + // Updates columns 0 to k-1-la, but two depends are sufficient: + // depend on k-1-la is all that is needed in next iteration; + // depend on 0 daisy chains all the trailing updates. + if (k-1-lookahead >= 0) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[k-1-lookahead]) \ + depend(inout:row[0]) + { + internal::gemm( + -one, B.sub(0, mt-1, k, k), + A.sub(k, k, 0, k-1-lookahead), + alph, B.sub(0, mt-1, 0, k-1-lookahead), + layout, priority_zero, queue_0 ); + } + } + + // Erase remote or workspace tiles. + #pragma omp task depend(inout:row[k]) + { + auto A_panel = A.sub(k, k, 0, k); + A_panel.releaseRemoteWorkspace(); + A_panel.releaseLocalWorkspace(); + + auto B_panel = B.sub(0, mt-1, k, k); + B_panel.releaseRemoteWorkspace(); + + // Copy back modifications to tiles in the B panel + // before they are erased. + B_panel.tileUpdateAllOrigin(); + B_panel.releaseLocalWorkspace(); + } + } + } + else { + // ---------------------------------------- + // Upper, Right case + // Forward sweep + for (int64_t k = 0; k < nt; ++k) { + scalar_t alph = k == 0 ? 
alpha : one; + + // panel (Akk tile) + #pragma omp task depend(inout:row[k]) priority(1) + { + // send A(k, k) to ranks owning block column B(:, k) + A .template tileBcast(k, k, B.sub(0, mt-1, k, k), layout); + if (blockFactorType != BlockFactor::QR) { + VT.template tileBcast(k, k, B.sub(0, mt-1, k, k), layout); + } + + // solve B(:, k) A(k, k) = alpha B(:, k) + internal::trsm_addmod( + Side::Right, Uplo::Upper, + alph, A.sub(k, k, k, k), + U.sub(k, k, k, k), + VT.sub(k, k, k, k), + std::move(S[k]), + B.sub(0, mt-1, k, k), + blockFactorType, ib, priority_one, layout, queue_1 ); + + // send A(k, j=k+1:nt-1) to ranks owning block column B(:, j) + BcastList bcast_list_A; + for (int64_t j = k+1; j < nt; ++j) + bcast_list_A.push_back({k, j, {B.sub(0, mt-1, j, j)}}); + A.template listBcast(bcast_list_A, layout); + + // send B(i=0:mt-1, k) to ranks owning + // block row B(i, k+1:nt-1) + BcastList bcast_list_B; + for (int64_t i = 0; i < mt; ++i) { + bcast_list_B.push_back( + {i, k, {B.sub(i, i, k+1, nt-1)}}); + } + B.template listBcast(bcast_list_B, layout); + } + + // lookahead update, B(:, k+1:k+la) -= B(:, k) A(k, k+1:k+la) + for (int64_t j = k+1; j < k+1+lookahead && j < nt; ++j) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[j]) priority(1) + { + internal::gemm( + -one, B.sub(0, mt-1, k, k), + A.sub(k, k, j, j), + alph, B.sub(0, mt-1, j, j), + layout, priority_one, j-k+1 ); + } + } + + // trailing update, + // B(:, k+1+la:nt-1) -= B(:, k) A(k, k+1+la:nt-1) + // Updates rows k+1+la to mt-1, but two depends are sufficient: + // depend on k+1+la is all that is needed in next iteration; + // depend on mt-1 daisy chains all the trailing updates. + if (k+1+lookahead < nt) { + #pragma omp task depend(in:row[k]) \ + depend(inout:row[k+1+lookahead]) \ + depend(inout:row[nt-1]) + { + internal::gemm( + -one, B.sub(0, mt-1, k, k), + A.sub(k, k, k+1+lookahead, nt-1), + alph, B.sub(0, mt-1, k+1+lookahead, nt-1), + layout, priority_zero, queue_0 ); + } + } + + // Erase remote or workspace tiles. + #pragma omp task depend(inout:row[k]) + { + auto A_panel = A.sub(k, k, k, nt-1); + A_panel.releaseRemoteWorkspace(); + A_panel.releaseLocalWorkspace(); + + auto B_panel = B.sub(0, mt-1, k, k); + B_panel.releaseRemoteWorkspace(); + + // Copy back modifications to tiles in the B panel + // before they are erased. + B_panel.tileUpdateAllOrigin(); + B_panel.releaseLocalWorkspace(); + } + } + } + } + + #pragma omp taskwait +} + +//------------------------------------------------------------------------------ +// Explicit instantiations. 
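Aside, not part of the patch: stripped of tiles, MPI, and tasks, the Lower/Left forward sweep above is an ordinary blocked forward substitution. A scalar sketch with made-up data, where plain division stands in for the internal::trsm_addmod diagonal solve and the inner loop stands in for the lookahead and trailing gemm updates:

#include <cstdio>

int main()
{
    const int n = 4;
    const double alpha = 2.0;
    // L is lower triangular; L * x = alpha * b has the solution x = {1, 2, 3, 4}
    double L[n][n] = { {2,0,0,0}, {1,3,0,0}, {0,1,4,0}, {1,0,1,5} };
    double b[n]    = { 1, 3.5, 7, 12 };

    // scale the right-hand side by alpha up front
    // (equivalent to folding alpha into the k == 0 step, as the code above does)
    for (int i = 0; i < n; ++i)
        b[i] *= alpha;

    for (int k = 0; k < n; ++k) {
        b[k] /= L[k][k];              // panel solve (internal::trsm_addmod in SLATE)
        for (int i = k+1; i < n; ++i)
            b[i] -= L[i][k] * b[k];   // lookahead + trailing updates (internal::gemm)
    }
    for (int k = 0; k < n; ++k)
        std::printf( "x[%d] = %g\n", k, b[k] );
    return 0;
}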
+// ---------------------------------------- +template +void trsm_addmod( + Side side, Uplo uplo, + float alpha, AddModFactors A, + Matrix B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod( + Side side, Uplo uplo, + float alpha, AddModFactors A, + Matrix B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod( + Side side, Uplo uplo, + float alpha, AddModFactors A, + Matrix B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod( + Side side, Uplo uplo, + float alpha, AddModFactors A, + Matrix B, + uint8_t* row, Options const& opts); + +// ---------------------------------------- +template +void trsm_addmod( + Side side, Uplo uplo, + double alpha, AddModFactors A, + Matrix B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod( + Side side, Uplo uplo, + double alpha, AddModFactors A, + Matrix B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod( + Side side, Uplo uplo, + double alpha, AddModFactors A, + Matrix B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod( + Side side, Uplo uplo, + double alpha, AddModFactors A, + Matrix B, + uint8_t* row, Options const& opts); + +// ---------------------------------------- +template +void trsm_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> A, + Matrix> B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> A, + Matrix> B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> A, + Matrix> B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> A, + Matrix> B, + uint8_t* row, Options const& opts); + +// ---------------------------------------- +template +void trsm_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> A, + Matrix> B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> A, + Matrix> B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> A, + Matrix> B, + uint8_t* row, Options const& opts); + +template +void trsm_addmod>( + Side side, Uplo uplo, + std::complex alpha, AddModFactors> A, + Matrix> B, + uint8_t* row, Options const& opts); + +} // namespace work +} // namespace slate diff --git a/test/test.cc b/test/test.cc index 2d46f782f..3a98f4b07 100644 --- a/test/test.cc +++ b/test/test.cc @@ -116,18 +116,22 @@ std::vector< testsweeper::routines_t > routines = { { "gesv_mixed", test_gesv, Section::gesv }, { "gesv_mixed_gmres", test_gesv, Section::gesv }, { "gesv_rbt", test_gesv, Section::gesv }, + { "gesv_addmod", test_gesv, Section::gesv }, + { "gesv_addmod_ir", test_gesv, Section::gesv }, { "gbsv", test_gbsv, Section::gesv }, { "", nullptr, Section::newline }, { "getrf", test_gesv, Section::gesv }, { "getrf_nopiv", test_gesv, Section::gesv }, { "getrf_tntpiv", test_gesv, Section::gesv }, + { "getrf_addmod", test_gesv, Section::gesv }, { "gbtrf", test_gbsv, Section::gesv }, { "", nullptr, Section::newline }, { "getrs", test_gesv, Section::gesv }, { "getrs_nopiv", test_gesv, Section::gesv }, { "getrs_tntpiv", test_gesv, Section::gesv }, + { "getrs_addmod", test_gesv, Section::gesv }, { "gbtrs", test_gbsv, Section::gesv }, { "", nullptr, Section::newline }, @@ -415,6 +419,10 @@ Params::Params(): itermax ("itermax", 
                7, ParamType::List, 30, -1, 1000000, "Maximum number of iterations for refinement"),
     fallback ("fallback",0, ParamType::List, 'y', "ny", "If refinement fails, fallback to a robust solver"),
     depth    ("depth",    5, ParamType::List, 2, 0, 1000, "Number of butterflies to apply"),
+    add_tol  ("addtol",   8, 2, ParamType::List, -1e-8, -inf, inf, "Threshold for additive modifications"),
+    woodbury ("woodbury",
+              8, ParamType::List, 'y', "ny", "Whether to apply the Woodbury formula"),
+    blockfactor("bf", 4, ParamType::List, slate::BlockFactor::SVD, slate::str2blockfactor, slate::blockfactor2str, "Block factor for the addmod routines"),
 
     // ----- output parameters
     // min, max are ignored
diff --git a/test/test.hh b/test/test.hh
index b0d48e0f3..7da4c9d11 100644
--- a/test/test.hh
+++ b/test/test.hh
@@ -143,6 +143,9 @@ public:
     testsweeper::ParamInt    itermax;
     testsweeper::ParamChar   fallback;
     testsweeper::ParamInt    depth;
+    testsweeper::ParamDouble add_tol;
+    testsweeper::ParamChar   woodbury;
+    testsweeper::ParamEnum< slate::BlockFactor > blockfactor;
 
     // ----- output parameters
     testsweeper::ParamScientific value;
diff --git a/test/test_gesv.cc b/test/test_gesv.cc
index 5a21e955d..05fc4bdf6 100644
--- a/test/test_gesv.cc
+++ b/test/test_gesv.cc
@@ -47,11 +47,11 @@ void test_gesv_work(Params& params, bool run)
     // get & mark input values
     slate::Op trans = slate::Op::NoTrans;
-    if (params.routine == "getrs")
+    if (params.routine == "getrs" || params.routine == "getrf_addmod")
         trans = params.trans();
 
     int64_t m;
-    if (params.routine == "getrf")
+    if (params.routine == "getrf" || params.routine == "getrf_addmod")
         m = params.dim.m();
     else
         m = params.dim.n(); // square, n-by-n
@@ -87,8 +87,21 @@ void test_gesv_work(Params& params, bool run)
     if (! supported)
         timer_level = 1;
 
+    slate::BlockFactor blockFactorType = slate::BlockFactor::SVD;
+    if (ends_with( params.routine, "_addmod" )
+        || ends_with( params.routine, "_addmod_ir" )) {
+        blockFactorType = params.blockfactor();
+    }
+
     // NoPiv and CALU ignore threshold.
double pivot_threshold = params.pivot_threshold(); + double add_tol = 0.0; + bool useWoodbury = false; + if (params.routine == "getrf_addmod" || params.routine == "gesv_addmod" + || params.routine == "gesv_addmod_ir") { + add_tol = params.add_tol(); + useWoodbury = params.woodbury() == 'y'; + } // mark non-standard output values params.time(); @@ -96,8 +109,8 @@ void test_gesv_work(Params& params, bool run) params.ref_time(); params.ref_gflops(); - bool do_getrs = params.routine == "getrs" - || (check && params.routine == "getrf"); + bool do_getrs = params.routine == "getrs" || params.routine == "getrs_addmod" + || (check && (params.routine == "getrf" || params.routine == "getrf_addmod")); if (do_getrs) { params.time2(); @@ -136,7 +149,8 @@ void test_gesv_work(Params& params, bool run) bool is_iterative = params.routine == "gesv_mixed" || params.routine == "gesv_mixed_gmres" - || params.routine == "gesv_rbt"; + || params.routine == "gesv_rbt" + || params.routine == "gesv_addmod_ir"; int64_t itermax = 0; bool fallback = true; @@ -179,6 +193,7 @@ void test_gesv_work(Params& params, bool run) slate::Options const opts = { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target}, + {slate::Option::BlockFactor, blockFactorType}, {slate::Option::MaxPanelThreads, panel_threads}, {slate::Option::InnerBlocking, ib}, {slate::Option::PivotThreshold, pivot_threshold}, @@ -188,6 +203,8 @@ void test_gesv_work(Params& params, bool run) {slate::Option::Depth, depth}, {slate::Option::MaxIterations, itermax}, {slate::Option::UseFallbackSolver, fallback}, + {slate::Option::AdditiveTolerance, add_tol}, + {slate::Option::UseWoodbury, useWoodbury}, }; int64_t info = 0; @@ -266,6 +283,7 @@ void test_gesv_work(Params& params, bool run) } slate::Pivots pivots; + slate::AddModFactors amfactors; slate::Options matgen_opts = {{slate::Option::Target, target}}; slate::generate_matrix(params.matrix, A, matgen_opts); @@ -305,7 +323,9 @@ void test_gesv_work(Params& params, bool run) if (params.routine == "gesv" || params.routine == "gesv_mixed" || params.routine == "gesv_mixed_gmres" - || params.routine == "gesv_rbt") + || params.routine == "gesv_rbt" + || params.routine == "gesv_addmod" + || params.routine == "gesv_addmod_ir") gflop = lapack::Gflop::gesv(n, nrhs); else gflop = lapack::Gflop::getrf(m, n); @@ -351,6 +371,19 @@ void test_gesv_work(Params& params, bool run) slate::gesv_rbt(A, B, X, iters, opts); params.iters() = iters; } + else if (params.routine == "gesv_addmod") { + // slate::lu_solve_addmod(A, B, opts); + // Using traditional BLAS/LAPACK name + slate::gesv_addmod(A, amfactors, B, opts); + } + else if (params.routine == "gesv_addmod_ir") { + int iters = 0; + slate::gesv_addmod_ir(A, amfactors, B, X, iters, opts); + params.iters() = iters; + } + else if (params.routine == "getrf_addmod" || params.routine == "getrs_addmod") { + slate::getrf_addmod(A, amfactors, opts); + } time = barrier_get_wtime(MPI_COMM_WORLD) - time; // compute and save timing/performance params.time() = time; @@ -392,9 +425,21 @@ void test_gesv_work(Params& params, bool run) else if (trans == slate::Op::ConjTrans) opA = conj_transpose( A ); - slate::lu_solve_using_factor( opA, pivots, B, opts ); - // Using traditional BLAS/LAPACK name - // slate::getrs(opA, pivots, B, opts); + if (params.routine == "getrs" + || params.routine == "getrf") + { + slate::lu_solve_using_factor( opA, pivots, B, opts ); + // Using traditional BLAS/LAPACK name + // slate::getrs(opA, pivots, B, opts); + } + else if ((check && params.routine == 
"getrf_addmod") + || params.routine == "getrs_addmod") + { + slate::getrs_addmod(amfactors, B, opts); + } + else { + slate_error("Unknown routine!"); + } // compute and save timing/performance time2 = barrier_get_wtime(MPI_COMM_WORLD) - time2;