diff --git a/scalapack_api/scalapack_gecon.cc b/scalapack_api/scalapack_gecon.cc index cc3e7fb8..365b3423 100644 --- a/scalapack_api/scalapack_gecon.cc +++ b/scalapack_api/scalapack_gecon.cc @@ -8,91 +8,22 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type anorm, blas::real_type* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSGECON(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void psgecon(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void psgecon_(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDGECON(const char* normstr, int* n, double* a, int* ia, int* ja, 
int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pdgecon(const char* normstr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pdgecon_(const char* normstr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCGECON(const char* normstr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pcgecon(const char* normstr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pcgecon_(const char* normstr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZGECON(const char* normstr, int* n, std::complex* a, int* ia, int* ja, int* desca, 
double* anorm, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pzgecon(const char* normstr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pzgecon_(const char* normstr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type anorm, blas::real_type* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +/// If scalar_t is real, irwork is integer. +/// If scalar_t is complex, irwork is real. 
+template +void slate_pgecon( + const char* norm_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type Anorm, blas::real_type* rcond, + scalar_t* work, blas_int lwork, + void* irwork, blas_int lirwork, + blas_int* info ) { Norm norm{}; - from_string( std::string( 1, normstr[0] ), &norm ); + from_string( std::string( 1, norm_str[0] ), &norm ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -104,15 +35,15 @@ void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int* // todo: extract the real info from getrf *info = 0; - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "gecon"); if (lwork == -1 || lirwork == -1) { *work = 0; if constexpr (std::is_same_v>) { - *(int*)irwork = 0; + *(blas_int*)irwork = 0; } else { *(blas::real_type*)irwork = 0; @@ -125,10 +56,13 @@ void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int* int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); - *rcond = slate::gecondest(norm, A, anorm, { + *rcond = slate::gecondest( norm, A, Anorm, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target}, {slate::Option::MaxPanelThreads, panel_threads}, @@ -136,5 +70,73 @@ void slate_pgecon(const char* 
normstr, int n, scalar_t* a, int ia, int ja, int* }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_psgecon BLAS_FORTRAN_NAME( psgecon, PSGECON ) +void SCALAPACK_psgecon( + const char* norm, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Anorm, float* rcond, + float* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_pgecon( + norm, *n, + A_data, *ia, *ja, descA, *Anorm, rcond, + work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pdgecon BLAS_FORTRAN_NAME( pdgecon, PDGECON ) +void SCALAPACK_pdgecon( + const char* norm, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Anorm, double* rcond, + double* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_pgecon( + norm, *n, + A_data, *ia, *ja, descA, *Anorm, rcond, + work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pcgecon BLAS_FORTRAN_NAME( pcgecon, PCGECON ) +void SCALAPACK_pcgecon( + const char* norm, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Anorm, float* rcond, + std::complex* work, blas_int const* lwork, + float* rwork, blas_int const* lrwork, + blas_int* info ) +{ + slate_pgecon( + norm, *n, + A_data, *ia, *ja, descA, *Anorm, rcond, + work, *lwork, rwork, *lrwork, info ); +} + +#define SCALAPACK_pzgecon BLAS_FORTRAN_NAME( pzgecon, PZGECON ) +void SCALAPACK_pzgecon( + const char* norm, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Anorm, double* rcond, + std::complex* work, blas_int const* lwork, + double* rwork, blas_int const* lrwork, + blas_int* info ) +{ + 
slate_pgecon( + norm, *n, + A_data, *ia, *ja, descA, *Anorm, rcond, + work, *lwork, rwork, *lrwork, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_gels.cc b/scalapack_api/scalapack_gels.cc index f2dd7667..15e92bd3 100644 --- a/scalapack_api/scalapack_gels.cc +++ b/scalapack_api/scalapack_gels.cc @@ -8,88 +8,16 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgels(const char* transstr, int m, int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t* work, int lwork, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSGELS(const char* trans, int* m, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -extern "C" void psgels(const char* trans, int* m, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -extern "C" void psgels_(const char* trans, int* m, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -// 
----------------------------------------------------------------------------- - -extern "C" void PDGELS(const char* trans, int* m, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -extern "C" void pdgels(const char* trans, int* m, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -extern "C" void pdgels_(const char* trans, int* m, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCGELS(const char* trans, int* m, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -extern "C" void pcgels(const char* trans, int* m, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -extern "C" void pcgels_(const char* trans, int* m, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -// 
----------------------------------------------------------------------------- - -extern "C" void PZGELS(const char* trans, int* m, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -extern "C" void pzgels(const char* trans, int* m, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -extern "C" void pzgels_(const char* trans, int* m, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* work, int* lwork, int* info) -{ - slate_pgels(trans, *m, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, work, *lwork, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pgels(const char* transstr, int m, int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t* work, int lwork, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_pgels( + const char* trans_str, blas_int m, blas_int n, blas_int nrhs, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + scalar_t* work, blas_int lwork, + blas_int* info ) { using real_t = blas::real_type; @@ -102,7 +30,7 @@ void slate_pgels(const char* transstr, int m, int n, int nrhs, scalar_t* a, int } Op trans{}; - from_string( std::string( 1, transstr[0] ), &trans ); + from_string( std::string( 1, trans_str[0] ), &trans ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -120,26 +48,32 @@ void slate_pgels(const char* transstr, int m, int n, int nrhs, scalar_t* a, int int64_t Bn = nrhs; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), 
desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); // Apply transpose auto opA = A; if (trans == slate::Op::Trans) - opA = transpose(A); + opA = transpose( A ); else if (trans == slate::Op::ConjTrans) opA = conj_transpose( A ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "gels"); - slate::gels(opA, B, { + slate::gels( opA, B, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target}, {slate::Option::MaxPanelThreads, panel_threads}, @@ -150,5 +84,73 @@ void slate_pgels(const char* transstr, int m, int n, int nrhs, scalar_t* a, int *info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_psgels BLAS_FORTRAN_NAME( psgels, PSGELS ) +void SCALAPACK_psgels( + const char* trans, blas_int const* m, blas_int const* n, blas_int* nrhs, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + float* work, blas_int const* lwork, + blas_int* info ) +{ + slate_pgels( + trans, *m, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, + work, *lwork, info ); +} + +#define SCALAPACK_pdgels BLAS_FORTRAN_NAME( pdgels, PDGELS ) +void SCALAPACK_pdgels( + const char* trans, blas_int const* m, blas_int const* n, blas_int* nrhs, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + double* work, blas_int const* lwork, + blas_int* info ) +{ + slate_pgels( + trans, *m, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, + work, *lwork, info ); +} + +#define SCALAPACK_pcgels BLAS_FORTRAN_NAME( pcgels, PCGELS ) +void SCALAPACK_pcgels( + const char* trans, blas_int const* m, blas_int const* n, 
blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* work, blas_int const* lwork, + blas_int* info ) +{ + slate_pgels( + trans, *m, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, + work, *lwork, info ); +} + +#define SCALAPACK_pzgels BLAS_FORTRAN_NAME( pzgels, PZGELS ) +void SCALAPACK_pzgels( + const char* trans, blas_int const* m, blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* work, blas_int const* lwork, + blas_int* info ) +{ + slate_pgels( + trans, *m, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, + work, *lwork, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_gemm.cc b/scalapack_api/scalapack_gemm.cc index 226f838a..05051918 100644 --- a/scalapack_api/scalapack_gemm.cc +++ b/scalapack_api/scalapack_gemm.cc @@ -8,118 +8,22 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Declarations -template< typename scalar_t > -void slate_pgemm(const char* transastr, const char* transbstr, int m, int n, int k, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t beta, scalar_t* c, int ic, int jc, int* descc); - -// ----------------------------------------------------------------------------- -// Each C interface for all Fortran interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type-generic C++ slate_pgemm routine. 
- -extern "C" void PDGEMM(const char* transa, const char* transb, int* m, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pdgemm(const char* transa, const char* transb, int* m, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pdgemm_(const char* transa, const char* transb, int* m, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PSGEMM(const char* transa, const char* transb, int* m, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void psgemm(const char* transa, const char* transb, int* m, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void psgemm_(const char* transa, const char* transb, int* m, int* n, int* k, 
float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCGEMM(const char* transa, const char* transb, int* m, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcgemm(const char* transa, const char* transb, int* m, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcgemm_(const char* transa, const char* transb, int* m, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZGEMM(const char* transa, const char* transb, int* m, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, 
c, *ic, *jc, descc); -} - -extern "C" void pzgemm(const char* transa, const char* transb, int* m, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzgemm_(const char* transa, const char* transb, int* m, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- -// Exposed type-specific API - -extern "C" void slate_pdgemm(const char* transa, const char* transb, int* m, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void slate_psgemm(const char* transa, const char* transb, int* m, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void slate_pcgemm(const char* transa, const char* transb, int* m, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, 
*alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void slate_pzgemm(const char* transa, const char* transb, int* m, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pgemm(transa, transb, *m, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgemm(const char* transastr, const char* transbstr, int m, int n, int k, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t beta, scalar_t* c, int ic, int jc, int* descc) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_pgemm( + const char* transA_str, const char* transB_str, + blas_int m, blas_int n, blas_int k, scalar_t alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + scalar_t beta, + scalar_t* C_data, blas_int ic, blas_int jc, blas_int const* descC ) { Op transA{}; Op transB{}; - from_string( std::string( 1, transastr[0] ), &transA ); - from_string( std::string( 1, transbstr[0] ), &transB ); + from_string( std::string( 1, transA_str[0] ), &transA ); + from_string( std::string( 1, transB_str[0] ), &transB ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -135,37 +39,122 @@ void slate_pgemm(const char* transastr, const char* transbstr, int m, int n, int int64_t Cn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); - - Cblacs_gridinfo(desc_CTXT(descc), &nprow, &npcol, &myprow, &mypcol); - auto C = slate::Matrix::fromScaLAPACK(desc_M(descc), desc_N(descc), c, desc_LLD(descc), desc_MB(descc), desc_NB(descc), grid_order, nprow, npcol, MPI_COMM_WORLD); - C = slate_scalapack_submatrix(Cm, Cn, C, ic, jc, descc); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), 
A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); + + Cblacs_gridinfo( desc_ctxt( descC ), &nprow, &npcol, &myprow, &mypcol ); + auto C = slate::Matrix::fromScaLAPACK( + desc_m( descC ), desc_n( descC ), C_data, desc_lld( descC ), + desc_mb( descC ), desc_nb( descC ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + C = slate_scalapack_submatrix( Cm, Cn, C, ic, jc, descC ); if (transA == blas::Op::Trans) - A = transpose(A); + A = transpose( A ); else if (transA == blas::Op::ConjTrans) A = conj_transpose( A ); if (transB == blas::Op::Trans) - B = transpose(B); + B = transpose( B ); else if (transB == blas::Op::ConjTrans) B = conj_transpose( B ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "gemm"); - slate::gemm(alpha, A, B, beta, C, { + slate::gemm( alpha, A, B, beta, C, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type-generic C++ slate_pgemm routine. 
+ +extern "C" { + +#define SCALAPACK_psgemm BLAS_FORTRAN_NAME( psgemm, PSGEMM ) +void SCALAPACK_psgemm( + const char* transA, const char* transB, + blas_int const* m, blas_int const* n, blas_int const* k, + float* alpha, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + float* beta, + float* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_pgemm( + transA, transB, *m, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pdgemm BLAS_FORTRAN_NAME( pdgemm, PDGEMM ) +void SCALAPACK_pdgemm( + const char* transA, const char* transB, + blas_int const* m, blas_int const* n, blas_int const* k, + double* alpha, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + double* beta, + double* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_pgemm( + transA, transB, *m, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pcgemm BLAS_FORTRAN_NAME( pcgemm, PCGEMM ) +void SCALAPACK_pcgemm( + const char* transA, const char* transB, + blas_int const* m, blas_int const* n, blas_int const* k, + std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_pgemm( + transA, transB, *m, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pzgemm BLAS_FORTRAN_NAME( pzgemm, PZGEMM ) +void SCALAPACK_pzgemm( + const char* transA, 
const char* transB, + blas_int const* m, blas_int const* n, blas_int const* k, + std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_pgemm( + transA, transB, *m, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_gesv.cc b/scalapack_api/scalapack_gesv.cc index 9ed71ed1..26a8f4b2 100644 --- a/scalapack_api/scalapack_gesv.cc +++ b/scalapack_api/scalapack_gesv.cc @@ -8,88 +8,16 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgesv(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* ipiv, scalar_t* b, int ib, int jb, int* descb, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSGESV(int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, int* ipiv, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void psgesv(int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, int* ipiv, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void psgesv_(int* n, int* 
nrhs, float* a, int* ia, int* ja, int* desca, int* ipiv, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDGESV(int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pdgesv(int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pdgesv_(int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCGESV(int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pcgesv(int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pcgesv_(int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZGESV(int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, 
int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pzgesv(int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pzgesv_(int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgesv(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pgesv(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* ipiv, scalar_t* b, int ib, int jb, int* descb, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_pgesv( + blas_int n, blas_int nrhs, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas_int* ipiv, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + blas_int* info ) { slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -106,19 +34,25 @@ void slate_pgesv(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* slate::Pivots pivots; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "gesv"); - slate::gesv(A, pivots, B, { + slate::gesv( A, pivots, B, { {slate::Option::Lookahead, lookahead}, 
{slate::Option::Target, target}, {slate::Option::MaxPanelThreads, panel_threads}, @@ -127,14 +61,14 @@ void slate_pgesv(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* // Extract pivots from SLATE's global Pivots structure into ScaLAPACK local ipiv array { - int isrcproc0 = 0; - int nb = desc_MB(desca); // ScaLAPACK style fixed nb - int64_t l_numrows = scalapack_numroc(An, nb, myprow, isrcproc0, nprow); + blas_int isrcproc0 = 0; + blas_int nb = desc_mb( descA ); // ScaLAPACK style fixed nb + int64_t l_numrows = scalapack_numroc( An, nb, myprow, isrcproc0, nprow ); // l_ipiv_rindx local ipiv row index (Scalapack 1-index) // for each local ipiv entry, find corresponding local-pivot and swap-pivot - for (int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { + for (blas_int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { // for ipiv index, convert to global indexing - int64_t g_ipiv_rindx = scalapack_indxl2g(&l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow); + int64_t g_ipiv_rindx = scalapack_indxl2g( &l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow ); // assuming uniform nb from scalapack (note 1-indexing) // figure out pivots(tile-index, offset) int64_t g_ipiv_tile_indx = (g_ipiv_rindx - 1) / nb; @@ -154,5 +88,65 @@ void slate_pgesv(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* *info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_psgesv BLAS_FORTRAN_NAME( psgesv, PSGESV ) +void SCALAPACK_psgesv( + blas_int const* n, blas_int* nrhs, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pgesv( *n, *nrhs, + A_data, *ia, *ja, descA, ipiv, + B_data, *ib, *jb, descB, info ); +} + +#define SCALAPACK_pdgesv BLAS_FORTRAN_NAME( pdgesv, PDGESV ) +void SCALAPACK_pdgesv( + blas_int const* n, blas_int* nrhs, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pgesv( *n, *nrhs, + A_data, *ia, *ja, descA, ipiv, + B_data, *ib, *jb, descB, info ); +} + +#define SCALAPACK_pcgesv BLAS_FORTRAN_NAME( pcgesv, PCGESV ) +void SCALAPACK_pcgesv( + blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pgesv( *n, *nrhs, + A_data, *ia, *ja, descA, ipiv, + B_data, *ib, *jb, descB, info ); +} + +#define SCALAPACK_pzgesv BLAS_FORTRAN_NAME( pzgesv, PZGESV ) +void SCALAPACK_pzgesv( + blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pgesv( *n, *nrhs, + A_data, *ia, *ja, descA, ipiv, + B_data, *ib, *jb, descB, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_gesv_mixed.cc b/scalapack_api/scalapack_gesv_mixed.cc index 527b29cc..7daad831 100644 --- a/scalapack_api/scalapack_gesv_mixed.cc +++ 
b/scalapack_api/scalapack_gesv_mixed.cc @@ -8,56 +8,17 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgesv_mixed(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* ipiv, scalar_t* b, int ib, int jb, int* descb, scalar_t* x, int ix, int jx, int* descx, int* iter, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -// ----------------------------------------------------------------------------- - -extern "C" void PDSGESV(int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, double* x, int* ix, int* jx, int* descx, int* iter, int* info) -{ - slate_pgesv_mixed(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, x, *ix, *jx, descx, iter, info); -} - -extern "C" void pdsgesv(int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, double* x, int* ix, int* jx, int* descx, int* iter, int* info) -{ - slate_pgesv_mixed(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, x, *ix, *jx, descx, iter, info); -} - -extern "C" void pdsgesv_(int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, double* x, int* ix, int* jx, int* descx, int* iter, int* info) -{ - slate_pgesv_mixed(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, x, *ix, *jx, descx, iter, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZCGESV(int* n, int* nrhs, std::complex* a, int* ia, 
int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, std::complex* x, int* ix, int* jx, int* descx, int* iter, int* info) -{ - slate_pgesv_mixed(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, x, *ix, *jx, descx, iter, info); -} - -extern "C" void pzcgesv(int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, std::complex* x, int* ix, int* jx, int* descx, int* iter, int* info) -{ - slate_pgesv_mixed(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, x, *ix, *jx, descx, iter, info); -} - -extern "C" void pzcgesv_(int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, std::complex* x, int* ix, int* jx, int* descx, int* iter, int* info) -{ - slate_pgesv_mixed(*n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, x, *ix, *jx, descx, iter, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pgesv_mixed(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* ipiv, scalar_t* b, int ib, int jb, int* descb, scalar_t* x, int ix, int jx, int* descx, int* iter, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_pgesv_mixed( + blas_int n, blas_int nrhs, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas_int* ipiv, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + scalar_t* x, blas_int ix, blas_int jx, blas_int const* descX, blas_int* iter, + blas_int* info ) { using real_t = blas::real_type; @@ -78,18 +39,27 @@ void slate_pgesv_mixed(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, slate::Pivots pivots; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); - - Cblacs_gridinfo(desc_CTXT(descx), &nprow, &npcol, &myprow, &mypcol); - auto X = slate::Matrix::fromScaLAPACK(desc_M(descx), desc_N(descx), x, desc_LLD(descx), desc_MB(descx), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - X = slate_scalapack_submatrix(Xm, Xn, X, ix, jx, descx); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), 
+ desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); + + Cblacs_gridinfo( desc_ctxt( descX ), &nprow, &npcol, &myprow, &mypcol ); + auto X = slate::Matrix::fromScaLAPACK( + desc_m( descX ), desc_n( descX ), x, desc_lld( descX ), + desc_mb( descX ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + X = slate_scalapack_submatrix( Xm, Xn, X, ix, jx, descX ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "gesv_mixed"); @@ -106,14 +76,14 @@ void slate_pgesv_mixed(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, // Extract pivots from SLATE's global Pivots structure into ScaLAPACK local ipiv array { - int isrcproc0 = 0; - int nb = desc_MB(desca); // ScaLAPACK style fixed nb - int64_t l_numrows = scalapack_numroc(An, nb, myprow, isrcproc0, nprow); + blas_int isrcproc0 = 0; + blas_int nb = desc_mb( descA ); // ScaLAPACK style fixed nb + int64_t l_numrows = scalapack_numroc( An, nb, myprow, isrcproc0, nprow ); // l_ipiv_rindx local ipiv row index (Scalapack 1-index) // for each local ipiv entry, find corresponding local-pivot and swap-pivot - for (int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { + for (blas_int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { // for ipiv index, convert to global indexing - int64_t g_ipiv_rindx = scalapack_indxl2g(&l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow); + int64_t g_ipiv_rindx = scalapack_indxl2g( &l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow ); // assuming uniform nb from scalapack (note 1-indexing) // figure out pivots(tile-index, offset) int64_t g_ipiv_tile_indx = (g_ipiv_rindx - 1) / nb; @@ -133,5 +103,43 @@ void slate_pgesv_mixed(int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, *info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_pdsgesv BLAS_FORTRAN_NAME( pdsgesv, PDSGESV ) +void SCALAPACK_pdsgesv( + blas_int const* n, blas_int* nrhs, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + double* X_data, blas_int const* ix, blas_int const* jx, blas_int const* descX, + blas_int* iter, blas_int* info ) +{ + slate_pgesv_mixed( *n, *nrhs, + A_data, *ia, *ja, descA, ipiv, + B_data, *ib, *jb, descB, + X_data, *ix, *jx, descX, iter, info ); +} + +#define SCALAPACK_pzcgesv BLAS_FORTRAN_NAME( pzcgesv, PZCGESV ) +void SCALAPACK_pzcgesv( + blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* X_data, blas_int const* ix, blas_int const* jx, blas_int const* descX, + blas_int* iter, blas_int* info ) +{ + slate_pgesv_mixed( *n, *nrhs, + A_data, *ia, *ja, descA, ipiv, + B_data, *ib, *jb, descB, + X_data, *ix, *jx, descX, iter, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_gesvd.cc b/scalapack_api/scalapack_gesvd.cc index 7372a375..08cfd1d8 100644 --- a/scalapack_api/scalapack_gesvd.cc +++ b/scalapack_api/scalapack_gesvd.cc @@ -8,99 +8,24 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgesvd(const char* jobustr, const char* jobvtstr, int m, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* s, scalar_t* u, int iu, int ju, int* descu, scalar_t* vt, int ivt, int jvt, int* 
descvt, scalar_t* work, int lwork, blas::real_type* rwork, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSGESVD(const char* jobustr, const char* jobvtstr, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* s, float* u, int* iu, int* ju, int* descu, float* vt, int* ivt, int* jvt, int* descvt, float* work, int* lwork, int* info) -{ - float dummy; - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, &dummy, info); -} - -extern "C" void psgesvd(const char* jobustr, const char* jobvtstr, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* s, float* u, int* iu, int* ju, int* descu, float* vt, int* ivt, int* jvt, int* descvt, float* work, int* lwork, int* info) -{ - float dummy; - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, &dummy, info); -} - -extern "C" void psgesvd_(const char* jobustr, const char* jobvtstr, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* s, float* u, int* iu, int* ju, int* descu, float* vt, int* ivt, int* jvt, int* descvt, float* work, int* lwork, int* info) -{ - float dummy; - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, &dummy, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDGESVD(const char* jobustr, const char* jobvtstr, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* s, double* u, int* iu, int* ju, int* descu, double* vt, int* ivt, int* jvt, int* descvt, double* work, int* lwork, int* info) -{ - double dummy; - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, 
work, *lwork, &dummy, info); -} - -extern "C" void pdgesvd(const char* jobustr, const char* jobvtstr, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* s, double* u, int* iu, int* ju, int* descu, double* vt, int* ivt, int* jvt, int* descvt, double* work, int* lwork, int* info) -{ - double dummy; - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, &dummy, info); -} - -extern "C" void pdgesvd_(const char* jobustr, const char* jobvtstr, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* s, double* u, int* iu, int* ju, int* descu, double* vt, int* ivt, int* jvt, int* descvt, double* work, int* lwork, int* info) -{ - double dummy; - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, &dummy, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCGESVD(const char* jobustr, const char* jobvtstr, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* s, std::complex* u, int* iu, int* ju, int* descu, std::complex* vt, int* ivt, int* jvt, int* descvt, std::complex* work, int* lwork, float* rwork, int* info) -{ - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, rwork, info); -} - -extern "C" void pcgesvd(const char* jobustr, const char* jobvtstr, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* s, std::complex* u, int* iu, int* ju, int* descu, std::complex* vt, int* ivt, int* jvt, int* descvt, std::complex* work, int* lwork, float* rwork, int* info) -{ - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, rwork, info); -} - -extern "C" void pcgesvd_(const char* jobustr, const char* jobvtstr, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* s, std::complex* 
u, int* iu, int* ju, int* descu, std::complex* vt, int* ivt, int* jvt, int* descvt, std::complex* work, int* lwork, float* rwork, int* info) -{ - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, rwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZGESVD(const char* jobustr, const char* jobvtstr, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* s, std::complex* u, int* iu, int* ju, int* descu, std::complex* vt, int* ivt, int* jvt, int* descvt, std::complex* work, int* lwork, float* rwork, int* info) -{ - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, rwork, info); -} - -extern "C" void pzgesvd_(const char* jobustr, const char* jobvtstr, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* s, std::complex* u, int* iu, int* ju, int* descu, std::complex* vt, int* ivt, int* jvt, int* descvt, std::complex* work, int* lwork, float* rwork, int* info) -{ - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, rwork, info); -} - -extern "C" void pzgesvd(const char* jobustr, const char* jobvtstr, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* s, std::complex* u, int* iu, int* ju, int* descu, std::complex* vt, int* ivt, int* jvt, int* descvt, std::complex* work, int* lwork, float* rwork, int* info) -{ - slate_pgesvd(jobustr, jobvtstr, *m, *n, a, *ia, *ja, desca, s, u, *iu, *ju, descu, vt, *ivt, *jvt, descvt, work, *lwork, rwork, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pgesvd(const char* jobustr, const char* jobvtstr, int m, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* s, scalar_t* u, int iu, int ju, int* descu, scalar_t* vt, 
int ivt, int jvt, int* descvt, scalar_t* work, int lwork, blas::real_type* rwork, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +template +void slate_pgesvd( + const char* jobu_str, const char* jobvt_str, blas_int m, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* Sigma, + scalar_t* U_data, blas_int iu, blas_int ju, blas_int const* descU, + scalar_t* VT_data, blas_int ivt, blas_int jvt, blas_int const* descVT, + scalar_t* work, blas_int lwork, + blas::real_type* rwork, + blas_int* info ) { Job jobu{}; Job jobvt{}; - from_string( std::string( 1, jobustr[0] ), &jobu ); - from_string( std::string( 1, jobvtstr[0] ), &jobvt ); + from_string( std::string( 1, jobu_str[0] ), &jobu ); + from_string( std::string( 1, jobvt_str[0] ), &jobvt ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -112,8 +37,8 @@ void slate_pgesvd(const char* jobustr, const char* jobvtstr, int m, int n, scala // todo: extract the real info from gesvd *info = 0; - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "gesvd"); @@ -125,7 +50,7 @@ void slate_pgesvd(const char* jobustr, const char* jobvtstr, int m, int n, scala } // Matrix sizes - int64_t min_mn = std::min(m, n); + int64_t min_mn = std::min( m, n ); int64_t Am = m; int64_t An = n; int64_t Um = m; @@ -134,21 +59,30 @@ void slate_pgesvd(const char* jobustr, const char* jobvtstr, int m, int n, scala int64_t VTn = n; // create SLATE matrices from the ScaLAPACK layouts - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), 
desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); slate::Matrix U; if (jobu == lapack::Job::Vec) { - Cblacs_gridinfo(desc_CTXT(descu), &nprow, &npcol, &myprow, &mypcol); - U = slate::Matrix::fromScaLAPACK(desc_M(descu), desc_N(descu), u, desc_LLD(descu), desc_MB(descu), desc_NB(descu), grid_order, nprow, npcol, MPI_COMM_WORLD); - U = slate_scalapack_submatrix(Um, Un, U, iu, ju, descu); + Cblacs_gridinfo( desc_ctxt( descU ), &nprow, &npcol, &myprow, &mypcol ); + U = slate::Matrix::fromScaLAPACK( + desc_m( descU ), desc_n( descU ), U_data, desc_lld( descU ), + desc_mb( descU ), desc_nb( descU ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + U = slate_scalapack_submatrix( Um, Un, U, iu, ju, descU ); } slate::Matrix VT; if (jobvt == lapack::Job::Vec) { - Cblacs_gridinfo(desc_CTXT(descvt), &nprow, &npcol, &myprow, &mypcol); - VT = slate::Matrix::fromScaLAPACK(desc_M(descvt), desc_N(descvt), vt, desc_LLD(descvt), desc_MB(descvt), desc_NB(descvt), grid_order, nprow, npcol, MPI_COMM_WORLD); - VT = slate_scalapack_submatrix(VTm, VTn, VT, ivt, jvt, descvt); + Cblacs_gridinfo( desc_ctxt( descVT ), &nprow, &npcol, &myprow, &mypcol ); + VT = slate::Matrix::fromScaLAPACK( + desc_m( descVT ), desc_n( descVT ), VT_data, desc_lld( descVT ), + desc_mb( descVT ), desc_nb( descVT ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + VT = slate_scalapack_submatrix( VTm, VTn, VT, ivt, jvt, descVT ); } std::vector< blas::real_type > Sigma_( n ); @@ -160,8 +94,96 @@ void slate_pgesvd(const char* jobustr, const char* jobvtstr, int m, int n, scala {slate::Option::InnerBlocking, ib} }); - std::copy(Sigma_.begin(), Sigma_.end(), s); + std::copy( Sigma_.begin(), Sigma_.end(), Sigma 
); +} + +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_psgesvd BLAS_FORTRAN_NAME( psgesvd, PSGESVD ) +void SCALAPACK_psgesvd( + const char* jobu_str, const char* jobvt_str, blas_int const* m, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Sigma, + float* U_data, blas_int const* iu, blas_int const* ju, blas_int const* descU, + float* VT_data, blas_int const* ivt, blas_int const* jvt, blas_int const* descVT, + float* work, blas_int const* lwork, + blas_int* info ) +{ + float dummy; + slate_pgesvd( + jobu_str, jobvt_str, *m, *n, + A_data, *ia, *ja, descA, + Sigma, + U_data, *iu, *ju, descU, + VT_data, *ivt, *jvt, descVT, + work, *lwork, &dummy, info ); +} + +#define SCALAPACK_pdgesvd BLAS_FORTRAN_NAME( pdgesvd, PDGESVD ) +void SCALAPACK_pdgesvd( + const char* jobu_str, const char* jobvt_str, blas_int const* m, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Sigma, + double* U_data, blas_int const* iu, blas_int const* ju, blas_int const* descU, + double* VT_data, blas_int const* ivt, blas_int const* jvt, blas_int const* descVT, + double* work, blas_int const* lwork, + blas_int* info ) +{ + double dummy; + slate_pgesvd( + jobu_str, jobvt_str, *m, *n, + A_data, *ia, *ja, descA, + Sigma, + U_data, *iu, *ju, descU, + VT_data, *ivt, *jvt, descVT, + work, *lwork, &dummy, info ); +} + +#define SCALAPACK_pcgesvd BLAS_FORTRAN_NAME( pcgesvd, PCGESVD ) +void SCALAPACK_pcgesvd( + const char* jobu_str, const char* jobvt_str, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Sigma, + std::complex* U_data, blas_int const* iu, blas_int const* ju, blas_int const* descU, + std::complex* VT_data, blas_int const* ivt, blas_int 
const* jvt, blas_int const* descVT, + std::complex* work, blas_int const* lwork, + float* rwork, + blas_int* info ) +{ + slate_pgesvd( + jobu_str, jobvt_str, *m, *n, + A_data, *ia, *ja, descA, + Sigma, + U_data, *iu, *ju, descU, + VT_data, *ivt, *jvt, descVT, + work, *lwork, rwork, info ); +} + +#define SCALAPACK_pzgesvd BLAS_FORTRAN_NAME( pzgesvd, PZGESVD ) +void SCALAPACK_pzgesvd( + const char* jobu_str, const char* jobvt_str, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Sigma, + std::complex* U_data, blas_int const* iu, blas_int const* ju, blas_int const* descU, + std::complex* VT_data, blas_int const* ivt, blas_int const* jvt, blas_int const* descVT, + std::complex* work, blas_int const* lwork, + double* rwork, + blas_int* info ) +{ + slate_pgesvd( + jobu_str, jobvt_str, *m, *n, + A_data, *ia, *ja, descA, + Sigma, + U_data, *iu, *ju, descU, + VT_data, *ivt, *jvt, descVT, + work, *lwork, rwork, info ); } +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_getrf.cc b/scalapack_api/scalapack_getrf.cc index 56f7d42e..646f531b 100644 --- a/scalapack_api/scalapack_getrf.cc +++ b/scalapack_api/scalapack_getrf.cc @@ -8,88 +8,15 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgetrf(int m, int n, scalar_t* a, int ia, int ja, int* desca, int* ipiv, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSGETRF(int* m, int* n, float* a, int* ia, int* 
ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -extern "C" void psgetrf(int* m, int* n, float* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -extern "C" void psgetrf_(int* m, int* n, float* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDGETRF(int* m, int* n, double* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -extern "C" void pdgetrf(int* m, int* n, double* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -extern "C" void pdgetrf_(int* m, int* n, double* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCGETRF(int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -extern "C" void pcgetrf(int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -extern "C" void pcgetrf_(int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZGETRF(int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -extern "C" void pzgetrf(int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, int* 
info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -extern "C" void pzgetrf_(int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, int* info) -{ - slate_pgetrf(*m, *n, a, *ia, *ja, desca, ipiv, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pgetrf(int m, int n, scalar_t* a, int ia, int ja, int* desca, int* ipiv, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +template +void slate_pgetrf( + blas_int m, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas_int* ipiv, + blas_int* info ) { slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -104,15 +31,18 @@ void slate_pgetrf(int m, int n, scalar_t* a, int ia, int ja, int* desca, int* ip slate::Pivots pivots; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "getrf"); - slate::getrf(A, pivots, { + slate::getrf( A, pivots, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target}, {slate::Option::MaxPanelThreads, 
panel_threads}, @@ -121,14 +51,14 @@ void slate_pgetrf(int m, int n, scalar_t* a, int ia, int ja, int* desca, int* ip // Extract pivots from SLATE's global Pivots structure into ScaLAPACK local ipiv array { - int isrcproc0 = 0; - int nb = desc_MB(desca); // ScaLAPACK style fixed nb - int64_t l_numrows = scalapack_numroc(An, nb, myprow, isrcproc0, nprow); + blas_int isrcproc0 = 0; + blas_int nb = desc_mb( descA ); // ScaLAPACK style fixed nb + int64_t l_numrows = scalapack_numroc( An, nb, myprow, isrcproc0, nprow ); // l_ipiv_rindx local ipiv row index (Scalapack 1-index) // for each local ipiv entry, find corresponding local-pivot and swap-pivot - for (int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { + for (blas_int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { // for ipiv index, convert to global indexing - int64_t g_ipiv_rindx = scalapack_indxl2g(&l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow); + int64_t g_ipiv_rindx = scalapack_indxl2g( &l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow ); // assuming uniform nb from scalapack (note 1-indexing) // figure out pivots(tile-index, offset) int64_t g_ipiv_tile_indx = (g_ipiv_rindx - 1) / nb; @@ -148,5 +78,65 @@ void slate_pgetrf(int m, int n, scalar_t* a, int ia, int ja, int* desca, int* ip *info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_psgetrf BLAS_FORTRAN_NAME( psgetrf, PSGETRF ) +void SCALAPACK_psgetrf( + blas_int const* m, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + blas_int* info ) +{ + slate_pgetrf( + *m, *n, + A_data, *ia, *ja, descA, + ipiv, info ); +} + +#define SCALAPACK_pdgetrf BLAS_FORTRAN_NAME( pdgetrf, PDGETRF ) +void SCALAPACK_pdgetrf( + blas_int const* m, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + blas_int* info ) +{ + slate_pgetrf( + *m, *n, + A_data, *ia, *ja, descA, + ipiv, info ); +} + +#define SCALAPACK_pcgetrf BLAS_FORTRAN_NAME( pcgetrf, PCGETRF ) +void SCALAPACK_pcgetrf( + blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + blas_int* info ) +{ + slate_pgetrf( + *m, *n, + A_data, *ia, *ja, descA, + ipiv, info ); +} + +#define SCALAPACK_pzgetrf BLAS_FORTRAN_NAME( pzgetrf, PZGETRF ) +void SCALAPACK_pzgetrf( + blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + blas_int* info ) +{ + slate_pgetrf( + *m, *n, + A_data, *ia, *ja, descA, + ipiv, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_getri.cc b/scalapack_api/scalapack_getri.cc index 787cee0a..19bc4c3b 100644 --- a/scalapack_api/scalapack_getri.cc +++ b/scalapack_api/scalapack_getri.cc @@ -8,90 +8,19 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgetri(int n, scalar_t* a, int ia, int ja, 
int* desca, int* ipiv, scalar_t* work, int lwork, int* iwork, int liwork, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSGETRI(int* n, float* a, int* ia, int* ja, int* desca, int* ipiv, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -extern "C" void psgetri(int* n, float* a, int* ia, int* ja, int* desca, int* ipiv, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -extern "C" void psgetri_(int* n, float* a, int* ia, int* ja, int* desca, int* ipiv, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDGETRI(int* n, double* a, int* ia, int* ja, int* desca, int* ipiv, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pdgetri(int* n, double* a, int* ia, int* ja, int* desca, int* ipiv, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pdgetri_(int* n, double* a, int* ia, int* ja, int* desca, int* ipiv, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCGETRI(int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* work, int* 
lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pcgetri(int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pcgetri_(int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZGETRI(int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pzgetri(int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* work, int* lwork, int* iwork, int* liwork, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_pgetri( + blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas_int* ipiv, + scalar_t* work, blas_int lwork, + blas_int* iwork, blas_int liwork, + blas_int* info ) { - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pzgetri_(int* n, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_pgetri(*n, a, *ia, *ja, desca, ipiv, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pgetri(int n, scalar_t* a, int ia, int ja, int* desca, int* ipiv, scalar_t* work, int lwork, int* iwork, int liwork, int* info) -{ - slate::Target target = TargetConfig::value(); + slate::Target target = TargetConfig::value( ); int verbose = VerboseConfig::value(); int64_t lookahead = LookaheadConfig::value(); int64_t panel_threads = PanelThreadsConfig::value(); @@ -106,10 +35,13 @@ void slate_pgetri(int n, scalar_t* a, int ia, int ja, int* desca, int* ipiv, sca }; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(n, n, A, ia, ja, desca); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( n, n, A, ia, ja, descA ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "getri"); @@ -120,22 +52,22 @@ void 
slate_pgetri(int n, scalar_t* a, int ia, int ja, int* desca, int* ipiv, sca // copy pivots from ScaLAPACK local ipiv array to SLATE global Pivots structure { // allocate pivots - int64_t min_mt_nt = std::min(A.mt(), A.nt()); - pivots.resize(min_mt_nt); + int64_t min_mt_nt = std::min( A.mt(), A.nt() ); + pivots.resize( min_mt_nt ); for (int64_t k = 0; k < min_mt_nt; ++k) { - int64_t diag_len = std::min(A.tileMb(k), A.tileNb(k)); - pivots.at(k).resize(diag_len); + int64_t diag_len = std::min( A.tileMb( k ), A.tileNb( k ) ); + pivots.at( k ).resize( diag_len ); } // transfer local ipiv to local part of pivots - int isrcproc0 = 0; - int nb = desc_MB(desca); // ScaLAPACK style fixed nb - int64_t l_numrows = scalapack_numroc(n, nb, myprow, isrcproc0, nprow); // local number of rows + blas_int isrcproc0 = 0; + blas_int nb = desc_mb( descA ); // ScaLAPACK style fixed nb + int64_t l_numrows = scalapack_numroc( n, nb, myprow, isrcproc0, nprow ); // local number of rows // l_rindx local row index (Scalapack 1-index) // for each local ipiv entry, find corresponding local-pivot information and swap-pivot information - for (int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { + for (blas_int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { // for local ipiv index, convert to global indexing - int64_t g_ipiv_rindx = scalapack_indxl2g(&l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow); + int64_t g_ipiv_rindx = scalapack_indxl2g( &l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow ); // assuming uniform nb from scalapack (note 1-indexing), find global tile, offset int64_t g_ipiv_tile_indx = (g_ipiv_rindx - 1) / nb; int64_t g_ipiv_tile_offset = (g_ipiv_rindx -1) % nb; @@ -146,7 +78,7 @@ void slate_pgetri(int n, scalar_t* a, int ia, int ja, int* desca, int* ipiv, sca int64_t tileIndexSwap = ((ipiv[l_ipiv_rindx - 1] - 1) / nb) - g_ipiv_tile_indx; int64_t elementOffsetSwap = (ipiv[l_ipiv_rindx - 1] - 1) % nb; // in the local pivot object, assign swap information - 
pivref = Pivot(tileIndexSwap, elementOffsetSwap); + pivref = Pivot( tileIndexSwap, elementOffsetSwap ); // if (verbose) { // printf("[%d,%d] getrs ipiv[%lld=%lld]=%lld -> pivots[%lld][%lld]=(%lld,%lld)\n", // myprow, mypcol, @@ -155,22 +87,90 @@ void slate_pgetri(int n, scalar_t* a, int ia, int ja, int* desca, int* ipiv, sca // llong( g_ipiv_tile_indx ), llong( g_ipiv_tile_offset ), // llong( tileIndexSwap ), llong( elementOffsetSwap )); // } - // fflush(0); + // fflush( 0 ); } // broadcast local pivot information to all processes for (int64_t k = 0; k < min_mt_nt; ++k) { MPI_Bcast(pivots.at(k).data(), sizeof(Pivot)*pivots.at(k).size(), - MPI_BYTE, A.tileRank(k, k), A.mpiComm()); + MPI_BYTE, A.tileRank( k, k ), A.mpiComm() ); } } - slate::getri(A, pivots, opts); + slate::getri( A, pivots, opts); // todo: extract the real info from getri *info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_psgetri BLAS_FORTRAN_NAME( psgetri, PSGETRI ) +void SCALAPACK_psgetri( + blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + float* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_pgetri( + *n, + A_data, *ia, *ja, descA, + ipiv, work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pdgetri BLAS_FORTRAN_NAME( pdgetri, PDGETRI ) +void SCALAPACK_pdgetri( + blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + double* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_pgetri( + *n, + A_data, *ia, *ja, descA, + ipiv, work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pcgetri BLAS_FORTRAN_NAME( pcgetri, PCGETRI ) +void SCALAPACK_pcgetri( + blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + std::complex* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_pgetri( + *n, + A_data, *ia, *ja, descA, + ipiv, work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pzgetri BLAS_FORTRAN_NAME( pzgetri, PZGETRI ) +void SCALAPACK_pzgetri( + blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + std::complex* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_pgetri( + *n, + A_data, *ia, *ja, descA, + ipiv, work, *lwork, iwork, *liwork, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_getrs.cc b/scalapack_api/scalapack_getrs.cc index 7b96a987..b4e6c26b 100644 --- a/scalapack_api/scalapack_getrs.cc +++ b/scalapack_api/scalapack_getrs.cc @@ -8,88 +8,16 @@ namespace slate { 
namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pgetrs(const char* transstr, int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* ipiv, scalar_t* b, int ib, int jb, int* descb, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSGETRS(const char* trans, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, int* ipiv, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void psgetrs(const char* trans, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, int* ipiv, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void psgetrs_(const char* trans, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, int* ipiv, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDGETRS(const char* trans, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pdgetrs(const char* trans, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, int* info) -{ - 
slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pdgetrs_(const char* trans, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, int* ipiv, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCGETRS(const char* trans, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pcgetrs(const char* trans, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pcgetrs_(const char* trans, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZGETRS(const char* trans, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pzgetrs(const char* trans, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -extern "C" void pzgetrs_(const char* trans, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, 
int* ipiv, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pgetrs(trans, *n, *nrhs, a, *ia, *ja, desca, ipiv, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pgetrs(const char* transstr, int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, int* ipiv, scalar_t* b, int ib, int jb, int* descb, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +template +void slate_pgetrs( + const char* trans_str, blas_int n, blas_int nrhs, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas_int* ipiv, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + blas_int* info ) { slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -102,7 +30,7 @@ void slate_pgetrs(const char* transstr, int n, int nrhs, scalar_t* a, int ia, in }; Op trans{}; - from_string( std::string( 1, transstr[0] ), &trans ); + from_string( std::string( 1, trans_str[0] ), &trans ); // Matrix sizes int64_t Am = n; @@ -111,14 +39,20 @@ void slate_pgetrs(const char* transstr, int n, int nrhs, scalar_t* a, int ia, in int64_t Bn = nrhs; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = 
slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "getrs"); @@ -129,22 +63,22 @@ void slate_pgetrs(const char* transstr, int n, int nrhs, scalar_t* a, int ia, in // copy pivots from ScaLAPACK local ipiv array to SLATE global Pivots structure { // allocate pivots - int64_t min_mt_nt = std::min(A.mt(), A.nt()); - pivots.resize(min_mt_nt); + int64_t min_mt_nt = std::min( A.mt(), A.nt() ); + pivots.resize( min_mt_nt ); for (int64_t k = 0; k < min_mt_nt; ++k) { - int64_t diag_len = std::min(A.tileMb(k), A.tileNb(k)); - pivots.at(k).resize(diag_len); + int64_t diag_len = std::min( A.tileMb(k), A.tileNb(k) ); + pivots.at( k ).resize( diag_len ); } // transfer local ipiv to local part of pivots - int isrcproc0 = 0; - int nb = desc_MB(desca); // ScaLAPACK style fixed nb - int64_t l_numrows = scalapack_numroc(n, nb, myprow, isrcproc0, nprow); // local number of rows + blas_int isrcproc0 = 0; + blas_int nb = desc_mb( descA ); // ScaLAPACK style fixed nb + int64_t l_numrows = scalapack_numroc( n, nb, myprow, isrcproc0, nprow ); // local number of rows // l_rindx local row index (Scalapack 1-index) // for each local ipiv entry, find corresponding local-pivot information and swap-pivot information - for (int l_ipiv_rindx=1; l_ipiv_rindx 
<= l_numrows; ++l_ipiv_rindx) { + for (blas_int l_ipiv_rindx=1; l_ipiv_rindx <= l_numrows; ++l_ipiv_rindx) { // for local ipiv index, convert to global indexing - int64_t g_ipiv_rindx = scalapack_indxl2g(&l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow); + int64_t g_ipiv_rindx = scalapack_indxl2g( &l_ipiv_rindx, &nb, &myprow, &isrcproc0, &nprow ); // assuming uniform nb from scalapack (note 1-indexing), find global tile, offset int64_t g_ipiv_tile_indx = (g_ipiv_rindx - 1) / nb; int64_t g_ipiv_tile_offset = (g_ipiv_rindx -1) % nb; @@ -155,7 +89,7 @@ void slate_pgetrs(const char* transstr, int n, int nrhs, scalar_t* a, int ia, in int64_t tileIndexSwap = ((ipiv[l_ipiv_rindx - 1] - 1) / nb) - g_ipiv_tile_indx; int64_t elementOffsetSwap = (ipiv[l_ipiv_rindx - 1] - 1) % nb; // in the local pivot object, assign swap information - pivref = Pivot(tileIndexSwap, elementOffsetSwap); + pivref = Pivot( tileIndexSwap, elementOffsetSwap ); // if (verbose) { // printf("[%d,%d] getrs ipiv[%lld=%lld]=%lld -> pivots[%lld][%lld]=(%lld,%lld)\n", // myprow, mypcol, @@ -164,30 +98,102 @@ void slate_pgetrs(const char* transstr, int n, int nrhs, scalar_t* a, int ia, in // llong( g_ipiv_tile_indx ), llong( g_ipiv_tile_offset ), // llong( tileIndexSwap ), llong( elementOffsetSwap )); // } - // fflush(0); + // fflush( 0 ); } // broadcast local pivot information to all processes for (int64_t k = 0; k < min_mt_nt; ++k) { - MPI_Bcast(pivots.at(k).data(), - sizeof(Pivot)*pivots.at(k).size(), - MPI_BYTE, A.tileRank(k, k), A.mpiComm()); + MPI_Bcast( pivots.at( k ).data(), + sizeof(Pivot)*pivots.at( k ).size(), + MPI_BYTE, A.tileRank( k, k ), A.mpiComm() ); } } // apply operators to the matrix auto opA = A; if (trans == slate::Op::Trans) - opA = transpose(A); + opA = transpose( A ); else if (trans == slate::Op::ConjTrans) opA = conj_transpose( A ); // call the SLATE getrs routine - slate::getrs(opA, pivots, B, opts); + slate::getrs( opA, pivots, B, opts); // todo: extract the real info from getrs 
*info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_psgetrs BLAS_FORTRAN_NAME( psgetrs, PSGETRS ) +void SCALAPACK_psgetrs( + const char* trans, blas_int const* n, blas_int* nrhs, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pgetrs( + trans, *n, *nrhs, + A_data, *ia, *ja, descA, + ipiv, + B_data, *ib, *jb, descB, + info ); +} + +#define SCALAPACK_pdgetrs BLAS_FORTRAN_NAME( pdgetrs, PDGETRS ) +void SCALAPACK_pdgetrs( + const char* trans, blas_int const* n, blas_int* nrhs, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pgetrs( + trans, *n, *nrhs, + A_data, *ia, *ja, descA, + ipiv, + B_data, *ib, *jb, descB, + info ); +} + +#define SCALAPACK_pcgetrs BLAS_FORTRAN_NAME( pcgetrs, PCGETRS ) +void SCALAPACK_pcgetrs( + const char* trans, blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pgetrs( + trans, *n, *nrhs, + A_data, *ia, *ja, descA, + ipiv, + B_data, *ib, *jb, descB, + info ); +} + +#define SCALAPACK_pzgetrs BLAS_FORTRAN_NAME( pzgetrs, PZGETRS ) +void SCALAPACK_pzgetrs( + const char* trans, blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* ipiv, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pgetrs( + trans, *n, *nrhs, + A_data, 
*ia, *ja, descA, + ipiv, + B_data, *ib, *jb, descB, + info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_heev.cc b/scalapack_api/scalapack_heev.cc index 336ed23e..e4c865aa 100644 --- a/scalapack_api/scalapack_heev.cc +++ b/scalapack_api/scalapack_heev.cc @@ -8,99 +8,23 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pheev(const char* jobzstr, const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* w, scalar_t* z, int iz, int jz, int* descz, scalar_t* work, int lwork, blas::real_type* rwork, int lrwork, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSSYEV(const char* jobzstr, const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* w, float* z, int* iz, int* jz, int* descz, float* work, int* lwork, int* info) -{ - float dummy; - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, info); -} - -extern "C" void pssyev(const char* jobzstr, const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* w, float* z, int* iz, int* jz, int* descz, float* work, int* lwork, int* info) -{ - float dummy; - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, info); -} - -extern "C" void pssyev_(const char* jobzstr, const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* w, float* z, int* iz, int* jz, int* descz, float* work, int* lwork, int* 
info) -{ - float dummy; - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDSYEV(const char* jobzstr, const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* w, double* z, int* iz, int* jz, int* descz, double* work, int* lwork, int* info) -{ - double dummy; - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, info); -} - -extern "C" void pdsyev(const char* jobzstr, const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* w, double* z, int* iz, int* jz, int* descz, double* work, int* lwork, int* info) -{ - double dummy; - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, info); -} - -extern "C" void pdsyev_(const char* jobzstr, const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* w, double* z, int* iz, int* jz, int* descz, double* work, int* lwork, int* info) -{ - double dummy; - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCHEEV(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pcheev(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, 
desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pcheev_(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZHEEV(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pzheev(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pzheev_(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_pheev(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pheev(const char* jobzstr, const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* w, scalar_t* z, int iz, int jz, int* descz, scalar_t* work, int lwork, blas::real_type* rwork, int lrwork, 
int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +template +void slate_pheev( + const char* jobz_str, const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* Lambda, + scalar_t* Z_data, blas_int iz, blas_int jz, blas_int const* descZ, + scalar_t* work, blas_int lwork, + blas::real_type* rwork, blas_int lrwork, + blas_int* info ) { Uplo uplo{}; Job jobz{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, jobzstr[0] ), &jobz ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, jobz_str[0] ), &jobz ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -112,8 +36,8 @@ void slate_pheev(const char* jobzstr, const char* uplostr, int n, scalar_t* a, i // todo: extract the real info from heev *info = 0; - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "heev"); @@ -130,14 +54,19 @@ void slate_pheev(const char* jobzstr, const char* uplostr, int n, scalar_t* a, i int64_t Zn = n; // create SLATE matrices from the ScaLAPACK layouts - auto A = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + auto A = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); slate::Matrix Z; if (jobz == lapack::Job::Vec) { - 
Cblacs_gridinfo(desc_CTXT(descz), &nprow, &npcol, &myprow, &mypcol); - Z = slate::Matrix::fromScaLAPACK(desc_M(descz), desc_N(descz), z, desc_LLD(descz), desc_MB(descz), desc_NB(descz), grid_order, nprow, npcol, MPI_COMM_WORLD); - Z = slate_scalapack_submatrix(Zm, Zn, Z, iz, jz, descz); + Cblacs_gridinfo( desc_ctxt( descZ ), &nprow, &npcol, &myprow, &mypcol ); + Z = slate::Matrix::fromScaLAPACK( + desc_m( descZ ), desc_n( descZ ), Z_data, desc_lld( descZ ), + desc_mb( descZ ), desc_nb( descZ ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + Z = slate_scalapack_submatrix( Zm, Zn, Z, iz, jz, descZ ); } std::vector< blas::real_type > Lambda_( n ); @@ -150,8 +79,88 @@ void slate_pheev(const char* jobzstr, const char* uplostr, int n, scalar_t* a, i {slate::Option::InnerBlocking, ib} }); - std::copy(Lambda_.begin(), Lambda_.end(), w); + std::copy( Lambda_.begin(), Lambda_.end(), Lambda ); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_pssyev BLAS_FORTRAN_NAME( pssyev, PSSYEV ) +void SCALAPACK_pssyev( + const char* jobz, const char* uplo, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Lambda, + float* Z_data, blas_int const* iz, blas_int const* jz, blas_int const* descZ, + float* work, blas_int const* lwork, + blas_int* info ) +{ + float dummy; + slate_pheev( + jobz, uplo, *n, + A_data, *ia, *ja, descA, + Lambda, + Z_data, *iz, *jz, descZ, + work, *lwork, &dummy, 1, info ); +} + +#define SCALAPACK_pdsyev BLAS_FORTRAN_NAME( pdsyev, PDSYEV ) +void SCALAPACK_pdsyev( + const char* jobz, const char* uplo, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Lambda, + double* Z_data, blas_int const* iz, blas_int const* jz, blas_int const* descZ, + double* work, blas_int const* lwork, + blas_int* info ) +{ + double dummy; + slate_pheev( + jobz, uplo, *n, + A_data, *ia, *ja, descA, + Lambda, + Z_data, *iz, *jz, descZ, + work, *lwork, &dummy, 1, info ); +} + +#define SCALAPACK_pcheev BLAS_FORTRAN_NAME( pcheev, PCHEEV ) +void SCALAPACK_pcheev( + const char* jobz, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Lambda, + std::complex* Z_data, blas_int const* iz, blas_int const* jz, blas_int const* descZ, + std::complex* work, blas_int const* lwork, + float* rwork, blas_int const* lrwork, + blas_int* info ) +{ + slate_pheev( + jobz, uplo, *n, + A_data, *ia, *ja, descA, + Lambda, + Z_data, *iz, *jz, descZ, + work, *lwork, rwork, *lrwork, info ); +} + +#define SCALAPACK_pzheev BLAS_FORTRAN_NAME( pzheev, PZHEEV ) +void SCALAPACK_pzheev( + const char* jobz, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Lambda, + std::complex* Z_data, blas_int const* iz, blas_int const* jz, blas_int const* 
descZ, + std::complex* work, blas_int const* lwork, + double* rwork, blas_int const* lrwork, + blas_int* info ) +{ + slate_pheev( + jobz, uplo, *n, + A_data, *ia, *ja, descA, + Lambda, + Z_data, *iz, *jz, descZ, + work, *lwork, rwork, *lrwork, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_heevd.cc b/scalapack_api/scalapack_heevd.cc index f7f4676f..c00dea75 100644 --- a/scalapack_api/scalapack_heevd.cc +++ b/scalapack_api/scalapack_heevd.cc @@ -8,99 +8,23 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pheevd(const char* jobzstr, const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* w, scalar_t* z, int iz, int jz, int* descz, scalar_t* work, int lwork, blas::real_type* rwork, int lrwork, int* iwork, int liwork, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSSYEVD(const char* jobzstr, const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* w, float* z, int* iz, int* jz, int* descz, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - float dummy; - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, iwork, *liwork, info); -} - -extern "C" void pssyevd(const char* jobzstr, const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* w, float* z, int* iz, int* jz, int* descz, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - float dummy; - slate_pheevd(jobzstr, 
uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, iwork, *liwork, info); -} - -extern "C" void pssyevd_(const char* jobzstr, const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* w, float* z, int* iz, int* jz, int* descz, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - float dummy; - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDSYEVD(const char* jobzstr, const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* w, double* z, int* iz, int* jz, int* descz, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - double dummy; - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, iwork, *liwork, info); -} - -extern "C" void pdsyevd(const char* jobzstr, const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* w, double* z, int* iz, int* jz, int* descz, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - double dummy; - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, iwork, *liwork, info); -} - -extern "C" void pdsyevd_(const char* jobzstr, const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* w, double* z, int* iz, int* jz, int* descz, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - double dummy; - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, &dummy, 1, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCHEEVD(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* 
work, int* lwork, float* rwork, int* lrwork, int* iwork, int* liwork, int* info) -{ - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, iwork, *liwork, info); -} - -extern "C" void pcheevd(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, float* rwork, int* lrwork, int* iwork, int* liwork, int* info) -{ - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, iwork, *liwork, info); -} - -extern "C" void pcheevd_(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, float* rwork, int* lrwork, int* iwork, int* liwork, int* info) -{ - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZHEEVD(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, double* rwork, int* lrwork, int* iwork, int* liwork, int* info) -{ - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, iwork, *liwork, info); -} - -extern "C" void pzheevd(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, double* rwork, int* lrwork, int* iwork, int* liwork, int* info) -{ - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, iwork, *liwork, info); -} - -extern "C" void 
pzheevd_(const char* jobzstr, const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* w, std::complex* z, int* iz, int* jz, int* descz, std::complex* work, int* lwork, double* rwork, int* lrwork, int* iwork, int* liwork, int* info) -{ - slate_pheevd(jobzstr, uplostr, *n, a, *ia, *ja, desca, w, z, *iz, *jz, descz, work, *lwork, rwork, *lrwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pheevd(const char* jobzstr, const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* w, scalar_t* z, int iz, int jz, int* descz, scalar_t* work, int lwork, blas::real_type* rwork, int lrwork, int* iwork, int liwork, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +template +void slate_pheevd( + const char* jobz_str, const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* Lambda, + scalar_t* Z_data, blas_int iz, blas_int jz, blas_int const* descZ, + scalar_t* work, blas_int lwork, blas::real_type* rwork, blas_int lrwork, + blas_int* iwork, blas_int liwork, + blas_int* info ) { Uplo uplo{}; Job jobz{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, jobzstr[0] ), &jobz ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, jobz_str[0] ), &jobz ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -112,8 +36,8 @@ void slate_pheevd(const char* jobzstr, const char* uplostr, int n, scalar_t* a, // todo: extract the real info from heevd *info = 0; - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( 
desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "heevd"); @@ -131,14 +55,19 @@ void slate_pheevd(const char* jobzstr, const char* uplostr, int n, scalar_t* a, int64_t Zn = n; // create SLATE matrices from the ScaLAPACK layouts - auto A = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + auto A = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); slate::Matrix Z; if (jobz == lapack::Job::Vec) { - Cblacs_gridinfo(desc_CTXT(descz), &nprow, &npcol, &myprow, &mypcol); - Z = slate::Matrix::fromScaLAPACK(desc_M(descz), desc_N(descz), z, desc_LLD(descz), desc_MB(descz), desc_NB(descz), grid_order, nprow, npcol, MPI_COMM_WORLD); - Z = slate_scalapack_submatrix(Zm, Zn, Z, iz, jz, descz); + Cblacs_gridinfo( desc_ctxt( descZ ), &nprow, &npcol, &myprow, &mypcol ); + Z = slate::Matrix::fromScaLAPACK( + desc_m( descZ ), desc_n( descZ ), Z_data, desc_lld( descZ ), + desc_mb( descZ ), desc_nb( descZ ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + Z = slate_scalapack_submatrix( Zm, Zn, Z, iz, jz, descZ ); } std::vector< blas::real_type > Lambda_( n ); @@ -151,8 +80,92 @@ void slate_pheevd(const char* jobzstr, const char* uplostr, int n, scalar_t* a, {slate::Option::InnerBlocking, ib} }); - std::copy(Lambda_.begin(), Lambda_.end(), w); + std::copy( Lambda_.begin(), Lambda_.end(), Lambda ); +} + +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_pssyevd BLAS_FORTRAN_NAME( pssyevd, PSSYEVD ) +void SCALAPACK_pssyevd( + const char* jobz_str, const char* uplo, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Lambda, + float* Z_data, blas_int const* iz, blas_int const* jz, blas_int const* descZ, + float* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + float dummy; + slate_pheevd( + jobz_str, uplo, *n, + A_data, *ia, *ja, descA, + Lambda, + Z_data, *iz, *jz, descZ, + work, *lwork, &dummy, 1, iwork, *liwork, info ); +} + +#define SCALAPACK_pdsyevd BLAS_FORTRAN_NAME( pdsyevd, PDSYEVD ) +void SCALAPACK_pdsyevd( + const char* jobz_str, const char* uplo, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Lambda, + double* Z_data, blas_int const* iz, blas_int const* jz, blas_int const* descZ, + double* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + double dummy; + slate_pheevd( + jobz_str, uplo, *n, + A_data, *ia, *ja, descA, + Lambda, + Z_data, *iz, *jz, descZ, + work, *lwork, &dummy, 1, iwork, *liwork, info ); +} + +#define SCALAPACK_pcheevd BLAS_FORTRAN_NAME( pcheevd, PCHEEVD ) +void SCALAPACK_pcheevd( + const char* jobz_str, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Lambda, + std::complex* Z_data, blas_int const* iz, blas_int const* jz, blas_int const* descZ, + std::complex* work, blas_int const* lwork, + float* rwork, blas_int const* lrwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_pheevd( + jobz_str, uplo, *n, + A_data, *ia, *ja, descA, + Lambda, + Z_data, *iz, *jz, descZ, + work, *lwork, rwork, *lrwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pzheevd BLAS_FORTRAN_NAME( pzheevd, PZHEEVD ) +void SCALAPACK_pzheevd( + const char* jobz_str, 
const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Lambda, + std::complex* Z_data, blas_int const* iz, blas_int const* jz, blas_int const* descZ, + std::complex* work, blas_int const* lwork, + double* rwork, blas_int const* lrwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_pheevd( + jobz_str, uplo, *n, + A_data, *ia, *ja, descA, + Lambda, + Z_data, *iz, *jz, descZ, + work, *lwork, rwork, *lrwork, iwork, *liwork, info ); } +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_hemm.cc b/scalapack_api/scalapack_hemm.cc index bd7e513e..432d95a0 100644 --- a/scalapack_api/scalapack_hemm.cc +++ b/scalapack_api/scalapack_hemm.cc @@ -8,77 +8,22 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Declarations -template< typename scalar_t > -void slate_phemm(const char* side, const char* uplo, int m, int n, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t beta, scalar_t* c, int ic, int jc, int* descc); - -// ----------------------------------------------------------------------------- -// Each C interface for all Fortran interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type-generic C++ slate_pgemm routine. 
- -extern "C" void PCHEMM(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_phemm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pchemm(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_phemm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pchemm_(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_phemm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZHEMM(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_phemm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzhemm(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_phemm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzhemm_(const 
char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_phemm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- -// Exposed type-specific API - -#define slate_pchemm BLAS_FORTRAN_NAME( slate_pchemm, SLATE_PCHEMM ) -#define slate_pzhemm BLAS_FORTRAN_NAME( slate_pzhemm, SLATE_PZHEMM ) - -extern "C" void slate_pchemm(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_phemm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void slate_pzhemm(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_phemm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_phemm(const char* sidestr, const char* uplostr, int m, int n, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t beta, scalar_t* c, int ic, int jc, int* descc) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_phemm( + const char* side_str, const char* uplo_str, + blas_int m, blas_int n, scalar_t alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + scalar_t beta, + scalar_t* C_data, blas_int ic, blas_int jc, blas_int const* descC ) { Side side{}; Uplo uplo{}; - from_string( std::string( 1, sidestr[0] ), &side ); - from_string( std::string( 1, uplostr[0] ), &uplo ); + from_string( std::string( 1, side_str[0] ), &side ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -93,34 +38,83 @@ void slate_phemm(const char* sidestr, const char* uplostr, int m, int n, scalar_ int64_t Cn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto AH = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - AH = slate_scalapack_submatrix(Am, An, AH, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); - - Cblacs_gridinfo(desc_CTXT(descc), &nprow, &npcol, &myprow, &mypcol); - auto C = slate::Matrix::fromScaLAPACK(desc_M(descc), desc_N(descc), c, desc_LLD(descc), desc_MB(descc), desc_NB(descc), grid_order, nprow, npcol, MPI_COMM_WORLD); - C = slate_scalapack_submatrix(Cm, Cn, C, ic, jc, descc); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto AH = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + 
grid_order, nprow, npcol, MPI_COMM_WORLD ); + AH = slate_scalapack_submatrix( Am, An, AH, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); + + Cblacs_gridinfo( desc_ctxt( descC ), &nprow, &npcol, &myprow, &mypcol ); + auto C = slate::Matrix::fromScaLAPACK( + desc_m( descC ), desc_n( descC ), C_data, desc_lld( descC ), + desc_mb( descC ), desc_nb( descC ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + C = slate_scalapack_submatrix( Cm, Cn, C, ic, jc, descC ); if (side == blas::Side::Left) - assert(AH.mt() == C.mt()); + assert( AH.mt() == C.mt() ); else - assert(AH.mt() == C.nt()); - assert(B.mt() == C.mt()); - assert(B.nt() == C.nt()); + assert( AH.mt() == C.nt() ); + assert( B.mt() == C.mt() ); + assert( B.nt() == C.nt() ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "hemm"); - slate::hemm(side, alpha, AH, B, beta, C, { + slate::hemm( side, alpha, AH, B, beta, C, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Each Fortran interface for all Fortran interfaces +// Each Fortran interface calls the type-generic C++ slate_pgemm routine. 
+ +extern "C" { + +#define SCALAPACK_pchemm BLAS_FORTRAN_NAME( pchemm, PCHEMM ) +void SCALAPACK_pchemm( + const char* side, const char* uplo, + blas_int const* m, blas_int const* n, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_phemm( + side, uplo, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pzhemm BLAS_FORTRAN_NAME( pzhemm, PZHEMM ) +void SCALAPACK_pzhemm( + const char* side, const char* uplo, + blas_int const* m, blas_int const* n, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC + ) +{ + slate_phemm( + side, uplo, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_her2k.cc b/scalapack_api/scalapack_her2k.cc index c03dff4a..fc817d7f 100644 --- a/scalapack_api/scalapack_her2k.cc +++ b/scalapack_api/scalapack_her2k.cc @@ -8,61 +8,21 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Declarations -template< typename scalar_t > -void slate_pher2k(const char* uplostr, const char* transstr, int n, int k, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, 
blas::real_type beta, scalar_t* c, int ic, int jc, int* descc); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PCHER2K(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, float* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pher2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcher2k(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, float* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pher2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcher2k_(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, float* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pher2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZHER2K(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, double* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pher2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzher2k(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, 
std::complex* b, int* ib, int* jb, int* descb, double* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pher2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzher2k_(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, double* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pher2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pher2k(const char* uplostr, const char* transstr, int n, int k, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, blas::real_type beta, scalar_t* c, int ic, int jc, int* descc) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_pher2k( + const char* uplo_str, const char* trans_str, blas_int n, blas_int k, scalar_t alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + blas::real_type beta, + scalar_t* C_data, blas_int ic, blas_int jc, blas_int const* descC ) { Uplo uplo{}; Op trans{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, transstr[0] ), &trans ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, trans_str[0] ), &trans ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -77,39 +37,87 @@ void slate_pher2k(const char* uplostr, const char* transstr, int n, int k, scala int64_t Cn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); - - Cblacs_gridinfo(desc_CTXT(descc), &nprow, &npcol, &myprow, &mypcol); - auto CH = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(descc), c, desc_LLD(descc), desc_NB(descc), grid_order, nprow, npcol, MPI_COMM_WORLD); - CH = slate_scalapack_submatrix(Cn, Cn, CH, ic, jc, descc); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA 
), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); + + Cblacs_gridinfo( desc_ctxt( descC ), &nprow, &npcol, &myprow, &mypcol ); + auto C = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descC ), C_data, desc_lld( descC ), desc_nb( descC ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + C = slate_scalapack_submatrix( Cn, Cn, C, ic, jc, descC ); if (trans == blas::Op::Trans) { - A = transpose(A); - B = transpose(B); + A = transpose( A ); + B = transpose( B ); } else if (trans == blas::Op::ConjTrans) { A = conj_transpose( A ); B = conj_transpose( B ); } - assert(A.mt() == CH.mt()); - assert(B.mt() == CH.mt()); - assert(A.nt() == B.nt()); + assert( A.mt() == C.mt() ); + assert( B.mt() == C.mt() ); + assert( A.nt() == B.nt() ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "her2k"); - slate::her2k(alpha, A, B, beta, CH, { + slate::her2k( alpha, A, B, beta, C, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_pcher2k BLAS_FORTRAN_NAME( pcher2k, PCHER2K ) +void SCALAPACK_pcher2k( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + float* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_pher2k( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pzher2k BLAS_FORTRAN_NAME( pzher2k, PZHER2K ) +void SCALAPACK_pzher2k( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + double* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_pher2k( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_herk.cc b/scalapack_api/scalapack_herk.cc index e35695f4..eb4394b8 100644 --- a/scalapack_api/scalapack_herk.cc +++ b/scalapack_api/scalapack_herk.cc @@ -8,61 +8,21 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Declarations -template< typename scalar_t > -void slate_pherk(const char* uplostr, const char* transstr, int n, int k, blas::real_type alpha, scalar_t* a, int ia, int ja, int* desca, blas::real_type beta, scalar_t* c, int 
ic, int jc, int* descc); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pherk - -extern "C" void PCHERK(const char* uplo, const char* trans, int* n, int* k, float* alpha, std::complex* a, int* ia, int* ja, int* desca, float* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pherk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcherk(const char* uplo, const char* trans, int* n, int* k, float* alpha, std::complex* a, int* ia, int* ja, int* desca, float* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pherk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcherk_(const char* uplo, const char* trans, int* n, int* k, float* alpha, std::complex* a, int* ia, int* ja, int* desca, float* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pherk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZHERK(const char* uplo, const char* trans, int* n, int* k, double* alpha, std::complex* a, int* ia, int* ja, int* desca, double* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pherk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzherk(const char* uplo, const char* trans, int* n, int* k, double* alpha, std::complex* a, int* ia, int* ja, int* desca, double* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_pherk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzherk_(const char* uplo, const char* trans, int* n, int* k, double* alpha, std::complex* a, int* ia, int* ja, int* desca, double* beta, std::complex* c, int* ic, int* jc, 
int* descc) -{ - slate_pherk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pherk(const char* uplostr, const char* transstr, int n, int k, blas::real_type alpha, scalar_t* a, int ia, int ja, int* desca, blas::real_type beta, scalar_t* c, int ic, int jc, int* descc) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +template +void slate_pherk( + const char* uplo_str, const char* trans_str, + blas_int n, blas_int k, blas::real_type alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type beta, + scalar_t* C_data, blas_int ic, blas_int jc, blas_int const* descC ) { Uplo uplo{}; Op transA{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, transstr[0] ), &transA ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, trans_str[0] ), &transA ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -76,29 +36,70 @@ void slate_pherk(const char* uplostr, const char* transstr, int n, int k, blas:: int64_t Cn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descc), &nprow, &npcol, &myprow, &mypcol); - auto C = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(descc), c, desc_LLD(descc), desc_NB(descc), grid_order, nprow, npcol, 
MPI_COMM_WORLD); - C = slate_scalapack_submatrix(Cm, Cn, C, ic, jc, descc); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descC ), &nprow, &npcol, &myprow, &mypcol ); + auto C = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descC ), C_data, desc_lld( descC ), desc_nb( descC ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + C = slate_scalapack_submatrix( Cm, Cn, C, ic, jc, descC ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "herk"); if (transA == blas::Op::Trans) - A = transpose(A); + A = transpose( A ); else if (transA == blas::Op::ConjTrans) A = conj_transpose( A ); - assert(A.mt() == C.mt()); + assert( A.mt() == C.mt() ); - slate::herk(alpha, A, beta, C, { + slate::herk( alpha, A, beta, C, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_pcherk BLAS_FORTRAN_NAME( pcherk, PCHERK ) +void SCALAPACK_pcherk( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, float* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_pherk( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pzherk BLAS_FORTRAN_NAME( pzherk, PZHERK ) +void SCALAPACK_pzherk( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, double* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_pherk( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, *beta, + C_data, *ic, *jc, descC ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_lange.cc b/scalapack_api/scalapack_lange.cc index e7123c5f..dde2b89d 100644 --- a/scalapack_api/scalapack_lange.cc +++ b/scalapack_api/scalapack_lange.cc @@ -8,91 +8,139 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - +//------------------------------------------------------------------------------ // Type generic function calls the SLATE routine -template< typename scalar_t > -blas::real_type slate_plange(const char* normstr, int m, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* work); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls 
the type generic slate_pher2k - -extern "C" float PSLANGE(const char* norm, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* work) +template +blas::real_type slate_plange(const char* norm_str, blas_int m, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* work); + +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" float PSLANGE(const char* norm, blas_int const* m, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pslange(const char* norm, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pslange(const char* norm, blas_int const* m, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pslange_(const char* norm, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pslange_(const char* norm, blas_int const* m, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" double PDLANGE(const char* norm, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double PDLANGE(const char* 
norm, blas_int const* m, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pdlange(const char* norm, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pdlange(const char* norm, blas_int const* m, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pdlange_(const char* norm, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pdlange_(const char* norm, blas_int const* m, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" float PCLANGE(const char* norm, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float PCLANGE(const char* norm, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pclange(const char* norm, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pclange(const char* norm, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int 
const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pclange_(const char* norm, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pclange_(const char* norm, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" double PZLANGE(const char* norm, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double PZLANGE(const char* norm, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pzlange(const char* norm, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pzlange(const char* norm, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pzlange_(const char* norm, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pzlange_(const char* norm, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* 
descA, + double* work) { - return slate_plange(norm, *m, *n, a, *ia, *ja, desca, work); + return slate_plange(norm, *m, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- -template< typename scalar_t > -blas::real_type slate_plange(const char* normstr, int m, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* work) +//------------------------------------------------------------------------------ +template +blas::real_type slate_plange(const char* norm_str, blas_int m, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* work) { Norm norm{}; - from_string( std::string( 1, normstr[0] ), &norm ); + from_string( std::string( 1, norm_str[0] ), &norm ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -104,16 +152,19 @@ blas::real_type slate_plange(const char* normstr, int m, int n, scalar int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "lange"); blas::real_type A_norm; - A_norm = slate::norm(norm, A, { + A_norm = slate::norm( norm, A, { {slate::Option::Target, target}, {slate::Option::Lookahead, lookahead} }); diff --git 
a/scalapack_api/scalapack_lanhe.cc b/scalapack_api/scalapack_lanhe.cc index 7194e426..7d8c6a90 100644 --- a/scalapack_api/scalapack_lanhe.cc +++ b/scalapack_api/scalapack_lanhe.cc @@ -8,61 +8,85 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - +//------------------------------------------------------------------------------ // Type generic function calls the SLATE routine -template< typename scalar_t > -blas::real_type slate_planhe(const char* normstr, const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* work); +template +blas::real_type slate_planhe(const char* norm_str, const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* work); -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
-// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" float PCLANHE(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float PCLANHE(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_planhe(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_planhe(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pclanhe(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pclanhe(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_planhe(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_planhe(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pclanhe_(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pclanhe_(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_planhe(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_planhe(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" double PZLANHE(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double PZLANHE(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, 
+ double* work) { - return slate_planhe(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_planhe(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pzlanhe(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pzlanhe(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_planhe(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_planhe(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pzlanhe_(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pzlanhe_(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_planhe(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_planhe(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- -template< typename scalar_t > -blas::real_type slate_planhe(const char* normstr, const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* work) +//------------------------------------------------------------------------------ +template +blas::real_type slate_planhe(const char* norm_str, const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* work) { Uplo uplo{}; Norm norm{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, normstr[0] ), &norm ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, norm_str[0] ), &norm ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -74,16 +98,18 @@ blas::real_type slate_planhe(const char* 
normstr, const char* uplostr, int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "lanhe"); blas::real_type A_norm = 1.0; - A_norm = slate::norm(norm, A, { + A_norm = slate::norm( norm, A, { {slate::Option::Target, target}, {slate::Option::Lookahead, lookahead} }); diff --git a/scalapack_api/scalapack_lansy.cc b/scalapack_api/scalapack_lansy.cc index b21b5b81..88fc6d32 100644 --- a/scalapack_api/scalapack_lansy.cc +++ b/scalapack_api/scalapack_lansy.cc @@ -8,93 +8,141 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - +//------------------------------------------------------------------------------ // Type generic function calls the SLATE routine -template< typename scalar_t > -blas::real_type slate_plansy(const char* normstr, const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* work); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" 
float PSLANSY(const char* norm, const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, float* work) +template +blas::real_type slate_plansy(const char* norm_str, const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* work); + +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" float PSLANSY(const char* norm, const char* uplo, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pslansy(const char* norm, const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pslansy(const char* norm, const char* uplo, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pslansy_(const char* norm, const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pslansy_(const char* norm, const char* uplo, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" double PDLANSY(const char* norm, const char* uplo, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double 
PDLANSY(const char* norm, const char* uplo, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pdlansy(const char* norm, const char* uplo, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pdlansy(const char* norm, const char* uplo, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pdlansy_(const char* norm, const char* uplo, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pdlansy_(const char* norm, const char* uplo, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" float PCLANSY(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float PCLANSY(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pclansy(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pclansy(const char* norm, 
const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pclansy_(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pclansy_(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" double PZLANSY(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double PZLANSY(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pzlansy(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pzlansy(const char* norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pzlansy_(const char* norm, const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pzlansy_(const char* 
norm, const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plansy(norm, uplo, *n, a, *ia, *ja, desca, work); + return slate_plansy(norm, uplo, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- -template< typename scalar_t > -blas::real_type slate_plansy(const char* normstr, const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* work) +//------------------------------------------------------------------------------ +template +blas::real_type slate_plansy(const char* norm_str, const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* work) { Uplo uplo{}; Norm norm{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, normstr[0] ), &norm ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, norm_str[0] ), &norm ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -106,16 +154,18 @@ blas::real_type slate_plansy(const char* normstr, const char* uplostr, int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::SymmetricMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::SymmetricMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); if (verbose && myprow == 0 && 
mypcol == 0) logprintf("%s\n", "lansy"); blas::real_type A_norm; - A_norm = slate::norm(norm, A, { + A_norm = slate::norm( norm, A, { {slate::Option::Target, target}, {slate::Option::Lookahead, lookahead} }); diff --git a/scalapack_api/scalapack_lantr.cc b/scalapack_api/scalapack_lantr.cc index e8945c6f..6b8dfaa3 100644 --- a/scalapack_api/scalapack_lantr.cc +++ b/scalapack_api/scalapack_lantr.cc @@ -11,95 +11,143 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - +//------------------------------------------------------------------------------ // Type generic function calls the SLATE routine -template< typename scalar_t > -blas::real_type slate_plantr(const char* normstr, const char* uplostr, const char* diagstr, int m, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* work); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" float PSLANTR(const char* norm, const char* uplo, const char* diag, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* work) +template +blas::real_type slate_plantr(const char* norm_str, const char* uplo_str, const char* diag_str, blas_int m, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* work); + +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" float PSLANTR(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pslantr(const char* norm, const char* uplo, const char* diag, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pslantr(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pslantr_(const char* norm, const char* uplo, const char* diag, int* m, int* n, float* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pslantr_(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" double PDLANTR(const char* norm, const char* uplo, const char* diag, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double PDLANTR(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plantr(norm, uplo, 
diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pdlantr(const char* norm, const char* uplo, const char* diag, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pdlantr(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pdlantr_(const char* norm, const char* uplo, const char* diag, int* m, int* n, double* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pdlantr_(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" float PCLANTR(const char* norm, const char* uplo, const char* diag, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float PCLANTR(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pclantr(const char* norm, const char* uplo, const char* diag, int* m, int* n, 
std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pclantr(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" float pclantr_(const char* norm, const char* uplo, const char* diag, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, float* work) +extern "C" float pclantr_(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ -extern "C" double PZLANTR(const char* norm, const char* uplo, const char* diag, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double PZLANTR(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pzlantr(const char* norm, const char* uplo, const char* diag, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pzlantr(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + std::complex* A_data, 
blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -extern "C" double pzlantr_(const char* norm, const char* uplo, const char* diag, int* m, int* n, std::complex* a, int* ia, int* ja, int* desca, double* work) +extern "C" double pzlantr_(const char* norm, const char* uplo, const char* diag, blas_int const* m, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* work) { - return slate_plantr(norm, uplo, diag, *m, *n, a, *ia, *ja, desca, work); + return slate_plantr(norm, uplo, diag, *m, *n, + A_data, *ia, *ja, descA, + work); } -// ----------------------------------------------------------------------------- -template< typename scalar_t > -blas::real_type slate_plantr(const char* normstr, const char* uplostr, const char* diagstr, int m, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* work) +//------------------------------------------------------------------------------ +template +blas::real_type slate_plantr(const char* norm_str, const char* uplo_str, const char* diag_str, blas_int m, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* work) { Norm norm{}; Uplo uplo{}; Diag diag{}; - from_string( std::string( 1, normstr[0] ), &norm ); - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, diagstr[0] ), &diag ); + from_string( std::string( 1, norm_str[0] ), &norm ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, diag_str[0] ), &diag ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -111,16 +159,18 @@ blas::real_type slate_plantr(const char* normstr, const char* uplostr, int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - int 
nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::TrapezoidMatrix::fromScaLAPACK(uplo, diag, desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::TrapezoidMatrix::fromScaLAPACK( + uplo, diag, desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); if (verbose && myprow == 0 && mypcol == 0) - logprintf("%s target %d\n", "lantr", (int)target); + logprintf("%s target %d\n", "lantr", (blas_int)target); blas::real_type A_norm; - A_norm = slate::norm(norm, A, { + A_norm = slate::norm( norm, A, { {slate::Option::Target, target}, {slate::Option::Lookahead, lookahead} }); diff --git a/scalapack_api/scalapack_pocon.cc b/scalapack_api/scalapack_pocon.cc index f5c42a09..799a1e7b 100644 --- a/scalapack_api/scalapack_pocon.cc +++ b/scalapack_api/scalapack_pocon.cc @@ -8,91 +8,22 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_ppocon(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type anorm, blas::real_type* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void 
PSPOCON(const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pspocon(const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pspocon_(const char* uplostr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDPOCON(const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pdpocon(const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pdpocon_(const char* uplostr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCPOCON(const char* uplostr, int* n, std::complex* a, int* ia, int* ja, 
int* desca, float* anorm, float* rcond, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pcpocon(const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pcpocon_(const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZPOCON(const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pzpocon(const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pzpocon_(const char* uplostr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_ppocon(uplostr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_ppocon(const char* 
uplostr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type anorm, blas::real_type* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +/// If scalar_t is real, irwork is integer. +/// If scalar_t is complex, irwork is real. +template +void slate_ppocon( + const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type Anorm, blas::real_type* rcond, + scalar_t* work, blas_int lwork, + void* irwork, blas_int lirwork, + blas_int* info ) { Uplo uplo{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -104,15 +35,15 @@ void slate_ppocon(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* // todo: extract the real info from getrf *info = 0; - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "pocon"); if (lwork == -1 || lirwork == -1) { *work = 0; if constexpr (std::is_same_v>) { - *(int*)irwork = 0; + *(blas_int*)irwork = 0; } else { *(blas::real_type*)irwork = 0; @@ -125,10 +56,12 @@ void slate_ppocon(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - auto A = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); + auto A = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, 
desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); - *rcond = slate::pocondest(slate::Norm::One, A, anorm, { + *rcond = slate::pocondest( slate::Norm::One, A, Anorm, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target}, {slate::Option::MaxPanelThreads, panel_threads}, @@ -136,5 +69,73 @@ void slate_ppocon(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_pspocon BLAS_FORTRAN_NAME( pspocon, PSPOCON ) +void SCALAPACK_pspocon( + const char* uplo, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Anorm, float* rcond, + float* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_ppocon( + uplo, *n, + A_data, *ia, *ja, descA, + *Anorm, rcond, work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pdpocon BLAS_FORTRAN_NAME( pdpocon, PDPOCON ) +void SCALAPACK_pdpocon( + const char* uplo, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Anorm, double* rcond, + double* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_ppocon( + uplo, *n, + A_data, *ia, *ja, descA, + *Anorm, rcond, work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pcpocon BLAS_FORTRAN_NAME( pcpocon, PCPOCON ) +void SCALAPACK_pcpocon( + const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* Anorm, float* rcond, + std::complex* work, blas_int const* lwork, + float* rwork, blas_int const* lrwork, + blas_int* info ) +{ + slate_ppocon( + uplo, *n, + A_data, *ia, *ja, 
descA, + *Anorm, rcond, work, *lwork, rwork, *lrwork, info ); +} + +#define SCALAPACK_pzpocon BLAS_FORTRAN_NAME( pzpocon, PZPOCON ) +void SCALAPACK_pzpocon( + const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* Anorm, double* rcond, + std::complex* work, blas_int const* lwork, + double* rwork, blas_int const* lrwork, + blas_int* info ) +{ + slate_ppocon( + uplo, *n, + A_data, *ia, *ja, descA, + *Anorm, rcond, work, *lwork, rwork, *lrwork, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_posv.cc b/scalapack_api/scalapack_posv.cc index e2759c15..29c186bc 100644 --- a/scalapack_api/scalapack_posv.cc +++ b/scalapack_api/scalapack_posv.cc @@ -11,91 +11,18 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_pposv(const char* uplostr, int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSPOSV(const char* uplo, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" void psposv(const char* uplo, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" 
void psposv_(const char* uplo, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDPOSV(const char* uplo, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" void pdposv(const char* uplo, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" void pdposv_(const char* uplo, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCPOSV(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" void pcposv(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" void pcposv_(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZPOSV(const char* uplo, 
int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" void pzposv(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" void pzposv_(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_pposv(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_pposv(const char* uplostr, int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_pposv( + const char* uplo_str, blas_int n, blas_int nrhs, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + blas_int* info ) { Uplo uplo{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -110,19 +37,24 @@ void slate_pposv(const char* uplostr, int n, int nrhs, scalar_t* a, int ia, int slate::Pivots pivots; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "posv"); - slate::posv(A, B, { + 
slate::posv( A, B, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target}, }); @@ -131,5 +63,65 @@ void slate_pposv(const char* uplostr, int n, int nrhs, scalar_t* a, int ia, int *info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_psposv BLAS_FORTRAN_NAME( psposv, PSPOSV ) +void SCALAPACK_psposv( + const char* uplo, blas_int const* n, blas_int* nrhs, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pposv( + uplo, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, info ); +} + +#define SCALAPACK_pdposv BLAS_FORTRAN_NAME( pdposv, PDPOSV ) +void SCALAPACK_pdposv( + const char* uplo, blas_int const* n, blas_int* nrhs, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pposv( + uplo, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, info ); +} + +#define SCALAPACK_pcposv BLAS_FORTRAN_NAME( pcposv, PCPOSV ) +void SCALAPACK_pcposv( + const char* uplo, blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) +{ + slate_pposv( + uplo, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, info ); +} + +#define SCALAPACK_pzposv BLAS_FORTRAN_NAME( pzposv, PZPOSV ) +void SCALAPACK_pzposv( + const char* uplo, blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* 
descB, + blas_int* info ) +{ + slate_pposv( + uplo, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_potrf.cc b/scalapack_api/scalapack_potrf.cc index ad9fc897..2ba84193 100644 --- a/scalapack_api/scalapack_potrf.cc +++ b/scalapack_api/scalapack_potrf.cc @@ -11,91 +11,17 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_ppotrf(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSPOTRF(const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pspotrf(const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pspotrf_(const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDPOTRF(const char* uplo, int* n, double* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pdpotrf(const char* uplo, int* n, double* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pdpotrf_(const char* uplo, 
int* n, double* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCPOTRF(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pcpotrf(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pcpotrf_(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZPOTRF(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pzpotrf(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pzpotrf_(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotrf(uplo, *n, a, *ia, *ja, desca, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_ppotrf(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_ppotrf( + const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas_int* info ) { Uplo uplo{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -106,15 +32,17 @@ void slate_ppotrf(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(An, An, A, ia, ja, desca); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( An, An, A, ia, ja, descA ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "potrf"); - slate::potrf(A, { + slate::potrf( A, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); @@ -123,5 +51,61 @@ void slate_ppotrf(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* *info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_pspotrf BLAS_FORTRAN_NAME( pspotrf, PSPOTRF ) +void SCALAPACK_pspotrf( + const char* uplo, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* info ) +{ + slate_ppotrf( + uplo, *n, + A_data, *ia, *ja, descA, + info ); +} + +#define SCALAPACK_pdpotrf BLAS_FORTRAN_NAME( pdpotrf, PDPOTRF ) +void SCALAPACK_pdpotrf( + const char* uplo, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* info ) +{ + slate_ppotrf( + uplo, *n, + A_data, *ia, *ja, descA, + info ); +} + +#define SCALAPACK_pcpotrf BLAS_FORTRAN_NAME( pcpotrf, PCPOTRF ) +void SCALAPACK_pcpotrf( + const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* info ) +{ + slate_ppotrf( + uplo, *n, + A_data, *ia, *ja, descA, + info ); +} + +#define SCALAPACK_pzpotrf BLAS_FORTRAN_NAME( pzpotrf, PZPOTRF ) +void SCALAPACK_pzpotrf( + const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* info ) +{ + slate_ppotrf( + uplo, *n, + A_data, *ia, *ja, descA, + info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_potri.cc b/scalapack_api/scalapack_potri.cc index 682c5b6b..30656e02 100644 --- a/scalapack_api/scalapack_potri.cc +++ b/scalapack_api/scalapack_potri.cc @@ -11,91 +11,17 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_ppotri(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, int* info); - -// 
----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSPOTRI(const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pspotri(const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pspotri_(const char* uplo, int* n, float* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDPOTRI(const char* uplo, int* n, double* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pdpotri(const char* uplo, int* n, double* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pdpotri_(const char* uplo, int* n, double* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCPOTRI(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pcpotri(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pcpotri_(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZPOTRI(const char* uplo, int* n, 
std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pzpotri(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -extern "C" void pzpotri_(const char* uplo, int* n, std::complex* a, int* ia, int* ja, int* desca, int* info) -{ - slate_ppotri(uplo, *n, a, *ia, *ja, desca, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_ppotri(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* desca, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +template +void slate_ppotri( + const char* uplo_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas_int* info ) { Uplo uplo{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -106,15 +32,17 @@ void slate_ppotri(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::HermitianMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(An, An, A, ia, ja, desca); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::HermitianMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = 
slate_scalapack_submatrix( An, An, A, ia, ja, descA ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "potri"); - slate::potri(A, { + slate::potri( A, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); @@ -123,5 +51,61 @@ void slate_ppotri(const char* uplostr, int n, scalar_t* a, int ia, int ja, int* *info = 0; } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_pspotri BLAS_FORTRAN_NAME( pspotri, PSPOTRI ) +void SCALAPACK_pspotri( + const char* uplo, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* info ) +{ + slate_ppotri( + uplo, *n, + A_data, *ia, *ja, descA, + info ); +} + +#define SCALAPACK_pdpotri BLAS_FORTRAN_NAME( pdpotri, PDPOTRI ) +void SCALAPACK_pdpotri( + const char* uplo, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* info ) +{ + slate_ppotri( + uplo, *n, + A_data, *ia, *ja, descA, + info ); +} + +#define SCALAPACK_pcpotri BLAS_FORTRAN_NAME( pcpotri, PCPOTRI ) +void SCALAPACK_pcpotri( + const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* info ) +{ + slate_ppotri( + uplo, *n, + A_data, *ia, *ja, descA, + info ); +} + +#define SCALAPACK_pzpotri BLAS_FORTRAN_NAME( pzpotri, PZPOTRI ) +void SCALAPACK_pzpotri( + const char* uplo, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + blas_int* info ) +{ + slate_ppotri( + uplo, *n, + A_data, *ia, *ja, descA, + info ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_potrs.cc b/scalapack_api/scalapack_potrs.cc index 47cce0e5..eb8bc156 100644 --- a/scalapack_api/scalapack_potrs.cc +++ 
b/scalapack_api/scalapack_potrs.cc @@ -11,119 +11,112 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_ppotrs(const char* uplostr, int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSPOTRS(const char* uplo, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_ppotrs( + const char* uplo_str, blas_int n, blas_int nrhs, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + blas_int* info ) { - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} + Uplo uplo{}; + from_string( std::string( 1, uplo_str[0] ), &uplo ); -extern "C" void pspotrs(const char* uplo, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} + slate::Target target = TargetConfig::value(); + int verbose = VerboseConfig::value(); + int64_t lookahead = LookaheadConfig::value(); + slate::GridOrder grid_order = slate_scalapack_blacs_grid_order(); -extern "C" void pspotrs_(const char* uplo, int* n, int* nrhs, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, int* info) -{ - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} + // create SLATE matrices from the ScaLAPACK layouts + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto Afull = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + auto Asub = slate_scalapack_submatrix( n, n, Afull, ia, ja, descA ); + slate::HermitianMatrix A( uplo, Asub ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto Bfull = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + slate::Matrix B = slate_scalapack_submatrix( n, nrhs, Bfull, ia, ja, descB ); -// ----------------------------------------------------------------------------- + if (verbose && 
myprow == 0 && mypcol == 0) + logprintf("%s\n", "potrs"); -extern "C" void PDPOTRS(const char* uplo, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} + slate::potrs( A, B, { + {slate::Option::Lookahead, lookahead}, + {slate::Option::Target, target}, + }); -extern "C" void pdpotrs(const char* uplo, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); + // todo: extract the real info + *info = 0; } -extern "C" void pdpotrs_(const char* uplo, int* n, int* nrhs, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, int* info) -{ - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
-// ----------------------------------------------------------------------------- +extern "C" { -extern "C" void PCPOTRS(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) +#define SCALAPACK_pspotrs BLAS_FORTRAN_NAME( pspotrs, PSPOTRS ) +void SCALAPACK_pspotrs( + const char* uplo, blas_int const* n, blas_int* nrhs, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) { - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); + slate_ppotrs( + uplo, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, info ); } -extern "C" void pcpotrs(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) +#define SCALAPACK_pdpotrs BLAS_FORTRAN_NAME( pdpotrs, PDPOTRS ) +void SCALAPACK_pdpotrs( + const char* uplo, blas_int const* n, blas_int* nrhs, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) { - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); + slate_ppotrs( + uplo, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, info ); } -extern "C" void pcpotrs_(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) +#define SCALAPACK_pcpotrs BLAS_FORTRAN_NAME( pcpotrs, PCPOTRS ) +void SCALAPACK_pcpotrs( + const char* uplo, blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) { - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, 
b, *ib, *jb, descb, info); + slate_ppotrs( + uplo, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, info ); } -// ----------------------------------------------------------------------------- - -extern "C" void PZPOTRS(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) +#define SCALAPACK_pzpotrs BLAS_FORTRAN_NAME( pzpotrs, PZPOTRS ) +void SCALAPACK_pzpotrs( + const char* uplo, blas_int const* n, blas_int* nrhs, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + blas_int* info ) { - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); + slate_ppotrs( + uplo, *n, *nrhs, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, info ); } -extern "C" void pzpotrs(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -extern "C" void pzpotrs_(const char* uplo, int* n, int* nrhs, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, int* info) -{ - slate_ppotrs(uplo, *n, *nrhs, a, *ia, *ja, desca, b, *ib, *jb, descb, info); -} - -// ----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_ppotrs(const char* uplostr, int n, int nrhs, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, int* info) -{ - Uplo uplo{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - - slate::Target target = TargetConfig::value(); - int verbose = VerboseConfig::value(); - int64_t lookahead = LookaheadConfig::value(); - slate::GridOrder grid_order = slate_scalapack_blacs_grid_order(); - - // create SLATE matrices from the ScaLAPACK 
layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto Afull = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - auto Asub = slate_scalapack_submatrix(n, n, Afull, ia, ja, desca); - slate::HermitianMatrix A(uplo, Asub); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto Bfull = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - slate::Matrix B = slate_scalapack_submatrix(n, nrhs, Bfull, ia, ja, descb); - - if (verbose && myprow == 0 && mypcol == 0) - logprintf("%s\n", "potrs"); - - slate::potrs(A, B, { - {slate::Option::Lookahead, lookahead}, - {slate::Option::Target, target}, - }); - - // todo: extract the real info - *info = 0; -} +} // extern "C" } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_slate.hh b/scalapack_api/scalapack_slate.hh index 1ce11d4d..c33766dc 100644 --- a/scalapack_api/scalapack_slate.hh +++ b/scalapack_api/scalapack_slate.hh @@ -11,59 +11,109 @@ #include "slate/slate.hh" -extern "C" void Cblacs_pinfo(int* mypnum, int* nprocs); -extern "C" void Cblacs_pcoord(int icontxt, int pnum, int* prow, int* pcol); -extern "C" void Cblacs_get(int icontxt, int what, int* val); +//============================================================================== +// Prototypes for BLACS routines. +extern "C" { + +// Get my process number and the number of processes. +void Cblacs_pinfo( blas_int* mypnum, blas_int* nprocs ); + +// Get row and col in 2D process grid for process pnum. +void Cblacs_pcoord( blas_int context, blas_int pnum, + blas_int* prow, blas_int* pcol ); + +// Lookup BLACS information. 
+void Cblacs_get( blas_int context, blas_int what, blas_int* val ); + +// Get 2D process grid size and my row and col in the grid. +void Cblacs_gridinfo( blas_int context, + blas_int* nprow, blas_int* npcol, + blas_int* myprow, blas_int* mypcol ); -#include +} // extern "C" +//============================================================================== namespace slate { namespace scalapack_api { #define logprintf(fmt, ...) \ - do { fprintf(stderr, "%s:%d %s(): " fmt, __FILE__, __LINE__, __func__, __VA_ARGS__); fflush(0); } while (0) + do { \ + fprintf( stdout, "%s:%d %s(): " fmt, \ + __FILE__, __LINE__, __func__, __VA_ARGS__ ); \ + fflush(0); \ + } while (0) + +enum dtype { + BlockCyclic2D = 1, + BlockCyclic2D_INB = 2, +}; -enum slate_scalapack_dtype {BLOCK_CYCLIC_2D=1, BLOCK_CYCLIC_2D_INB=2}; -enum slate_scalapack_desc {DTYPE_=0, CTXT_, M_, N_, MB_, NB_, RSRC_, CSRC_, LLD_}; -enum slate_scalapack_desc_inb {DTYPE_INB=0, CTXT_INB, M_INB, N_INB, IMB_INB, INB_INB, MB_INB, NB_INB, RSRC_INB, CSRC_INB, LLD_INB}; +enum desc { + DTYPE_ = 0, + CTXT_, + M_, + N_, + MB_, + NB_, + RSRC_, + CSRC_, + LLD_, +}; + +enum desc_inb { + DTYPE_INB = 0, + CTXT_INB, + M_INB, + N_INB, + IMB_INB, + INB_INB, + MB_INB, + NB_INB, + RSRC_INB, + CSRC_INB, + LLD_INB +}; -inline int desc_CTXT(int* desca) +//------------------------------------------------------------------------------ +inline blas_int desc_ctxt( blas_int const* descA ) { - return desca[1]; + return descA[ CTXT_ ]; } -inline int desc_M(int* desca) +inline blas_int desc_m( blas_int const* descA ) { - return desca[2]; + return descA[ M_ ]; } -inline int desc_N(int* desca) +inline blas_int desc_n( blas_int const* descA ) { - return desca[3]; + return descA[ N_ ]; } -inline int desc_MB(int* desca) +inline blas_int desc_mb( blas_int const* descA ) { - return (desca[0] == BLOCK_CYCLIC_2D) ? desca[MB_] : desca[MB_INB]; + return (descA[ DTYPE_ ] == BlockCyclic2D) ? 
descA[ MB_ ] : descA[ MB_INB ]; } -inline int desc_NB(int* desca) +inline blas_int desc_nb( blas_int const* descA ) { - return (desca[0] == BLOCK_CYCLIC_2D) ? desca[NB_] : desca[NB_INB]; + return (descA[ DTYPE_ ] == BlockCyclic2D) ? descA[ NB_ ] : descA[ NB_INB ]; } -inline int desc_LLD(int* desca) +inline blas_int desc_lld( blas_int const* descA ) { - return (desca[0] == BLOCK_CYCLIC_2D) ? desca[LLD_] : desca[LLD_INB]; + return (descA[ DTYPE_ ] == BlockCyclic2D) ? descA[ LLD_ ] : descA[ LLD_INB ]; } +//------------------------------------------------------------------------------ +/// Determine grid order for default BLACS context. inline slate::GridOrder slate_scalapack_blacs_grid_order() { // if nprocs == 1, the grid layout is irrelevant, all-OK // if nprocs > 1 check the grid location of process-number-1 pnum(1). // if pnum(1) is at grid-coord(0, 1) then grid is col-major // else if pnum(1) is not at grid-coord(0, 1) then grid is row-major - int mypnum, nprocs, prow, pcol, icontxt=-1, imone=-1, izero=0, pnum_1=1; + blas_int mypnum, nprocs, prow, pcol, icontxt=-1, imone=-1, izero=0, pnum_1=1; Cblacs_pinfo( &mypnum, &nprocs ); if (nprocs == 1) // only one process, so col-major grid-layout return slate::GridOrder::Col; @@ -77,65 +127,80 @@ inline slate::GridOrder slate_scalapack_blacs_grid_order() } } -template< typename scalar_t > -inline slate::Matrix slate_scalapack_submatrix(int Am, int An, slate::Matrix& A, int ia, int ja, int* desca) +template +slate::Matrix slate_scalapack_submatrix( + blas_int Am, blas_int An, slate::Matrix& A, + blas_int ia, blas_int ja, blas_int const* descA ) { - // logprintf("Am %d An %d ia %d ja %d desc_MB(desca) %d desc_NB(desca) %d A.m() %ld A.n() %ld \n", Am, An, ia, ja, desc_MB(desca), desc_NB(desca), A.m(), A.n()); - if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) return A; - assert((ia-1) % desc_MB(desca) == 0); - assert((ja-1) % desc_NB(desca) == 0); - assert(Am % desc_MB(desca) == 0); - assert(An % desc_NB(desca) == 0); 
- int64_t i1 = (ia-1)/desc_MB(desca); - int64_t i2 = i1 + (Am/desc_MB(desca)) - 1; - int64_t j1 = (ja-1)/desc_NB(desca); - int64_t j2 = j1 + (An/desc_NB(desca)) - 1; - return A.sub(i1, i2, j1, j2); + // logprintf("Am %d An %d ia %d ja %d desc_mb( descA ) %d desc_nb( descA ) %d A.m() %ld A.n() %ld \n", Am, An, ia, ja, desc_mb( descA ), desc_nb( descA ), A.m(), A.n()); + if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) + return A; + assert( (ia-1) % desc_mb( descA ) == 0 ); + assert( (ja-1) % desc_nb( descA ) == 0 ); + assert( Am % desc_mb( descA ) == 0 ); + assert( An % desc_nb( descA ) == 0 ); + int64_t i1 = (ia-1)/desc_mb( descA ); + int64_t i2 = i1 + (Am/desc_mb( descA )) - 1; + int64_t j1 = (ja-1)/desc_nb( descA ); + int64_t j2 = j1 + (An/desc_nb( descA )) - 1; + return A.sub( i1, i2, j1, j2 ); } -template< typename scalar_t > -inline slate::SymmetricMatrix slate_scalapack_submatrix(int Am, int An, slate::SymmetricMatrix& A, int ia, int ja, int* desca) +template +slate::SymmetricMatrix slate_scalapack_submatrix( + blas_int Am, blas_int An, slate::SymmetricMatrix& A, + blas_int ia, blas_int ja, blas_int const* descA ) { - //logprintf("Am %d An %d ia %d ja %d desc_MB(desca) %d desc_NB(desca) %d \n", Am, An, ia, ja, desc_MB(desca), desc_NB(desca)); - if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) return A; - assert((ia-1) % desc_MB(desca) == 0); - assert(Am % desc_MB(desca) == 0); - int64_t i1 = (ia-1)/desc_MB(desca); - int64_t i2 = i1 + (Am/desc_MB(desca)) - 1; - return A.sub(i1, i2); + //logprintf("Am %d An %d ia %d ja %d desc_mb( descA ) %d desc_nb( descA ) %d \n", Am, An, ia, ja, desc_mb( descA ), desc_nb( descA )); + if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) + return A; + assert( (ia-1) % desc_mb( descA ) == 0 ); + assert( Am % desc_mb( descA ) == 0 ); + int64_t i1 = (ia-1)/desc_mb( descA ); + int64_t i2 = i1 + (Am/desc_mb( descA )) - 1; + return A.sub( i1, i2 ); } -template< typename scalar_t > -inline slate::TriangularMatrix 
slate_scalapack_submatrix(int Am, int An, slate::TriangularMatrix& A, int ia, int ja, int* desca) +template +slate::TriangularMatrix slate_scalapack_submatrix( + blas_int Am, blas_int An, slate::TriangularMatrix& A, + blas_int ia, blas_int ja, blas_int const* descA ) { - if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) return A; - assert((ia-1) % desc_MB(desca) == 0); - assert(Am % desc_MB(desca) == 0); - int64_t i1 = (ia-1)/desc_MB(desca); - int64_t i2 = i1 + (Am/desc_MB(desca)) - 1; - return A.sub(i1, i2); + if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) + return A; + assert( (ia-1) % desc_mb( descA ) == 0 ); + assert( Am % desc_mb( descA ) == 0 ); + int64_t i1 = (ia-1)/desc_mb( descA ); + int64_t i2 = i1 + (Am/desc_mb( descA )) - 1; + return A.sub( i1, i2 ); } -template< typename scalar_t > -inline slate::TrapezoidMatrix slate_scalapack_submatrix(int Am, int An, slate::TrapezoidMatrix& A, int ia, int ja, int* desca) +template +slate::TrapezoidMatrix slate_scalapack_submatrix( + blas_int Am, blas_int An, slate::TrapezoidMatrix& A, + blas_int ia, blas_int ja, blas_int const* descA ) { - if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) return A; - assert((ia-1) % desc_NB(desca) == 0); - assert(An % desc_NB(desca) == 0); - int64_t i1 = (ia-1)/desc_NB(desca); - int64_t i2 = i1 + (Am/desc_NB(desca)) - 1; - return A.sub(i1, i2, i2); + if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) + return A; + assert( (ia-1) % desc_nb( descA ) == 0); + assert( An % desc_nb( descA) == 0); + int64_t i1 = (ia-1)/desc_nb( descA ); + int64_t i2 = i1 + (Am/desc_nb( descA )) - 1; + return A.sub( i1, i2, i2 ); } -template< typename scalar_t > -inline slate::HermitianMatrix slate_scalapack_submatrix(int Am, int An, slate::HermitianMatrix& A, int ia, int ja, int* desca) +template +slate::HermitianMatrix slate_scalapack_submatrix( + blas_int Am, blas_int An, slate::HermitianMatrix& A, + blas_int ia, blas_int ja, blas_int const* descA ) { - if (ia == 1 && ja == 1 && Am == 
A.m() && An == A.n()) return A; - assert((ia-1) % desc_NB(desca) == 0); - assert(An % desc_NB(desca) == 0); - int64_t i1 = (ia-1)/desc_NB(desca); - int64_t i2 = i1 + (Am/desc_NB(desca)) - 1; - return A.sub(i1, i2); + if (ia == 1 && ja == 1 && Am == A.m() && An == A.n()) + return A; + assert( (ia-1) % desc_nb( descA ) == 0 ); + assert( An % desc_nb( descA ) == 0 ); + int64_t i1 = (ia-1)/desc_nb( descA ); + int64_t i2 = i1 + (Am/desc_nb( descA )) - 1; + return A.sub( i1, i2 ); } //============================================================================== @@ -384,32 +449,55 @@ private: int64_t lookahead_; }; -// ----------------------------------------------------------------------------- -// helper funtion to check and do type conversion -// TODO: this is duplicated at the testing module -inline int int64_to_int(int64_t n) +//============================================================================== +// This is duplicated from blaspp/src/blas_internal.hh + +//------------------------------------------------------------------------------ +/// @see to_blas_int +/// +inline blas_int to_blas_int_( int64_t x, const char* x_str ) { - if (sizeof(int64_t) > sizeof(blas_int)) - assert(n < std::numeric_limits::max()); - int n_ = (int)n; - return n_; + if constexpr (sizeof(int64_t) > sizeof(blas_int)) { + blas_error_if_msg( x > std::numeric_limits::max(), "%s", x_str ); + } + return blas_int( x ); } +//---------------------------------------- +/// Convert int64_t to blas_int. +/// If blas_int is 64-bit, this does nothing. +/// If blas_int is 32-bit, throws if x > INT_MAX, so conversion would overflow. +/// +/// Note this is in src/blas_internal.hh, so this macro won't pollute +/// the namespace when apps #include . 
+/// +#define to_blas_int( x ) to_blas_int_( x, #x ) + +//============================================================================== + + // ----------------------------------------------------------------------------- // TODO: this is duplicated at the testing module -#define scalapack_numroc BLAS_FORTRAN_NAME(numroc,NUMROC) -extern "C" int scalapack_numroc(int* n, int* nb, int* iproc, int* isrcproc, int* nprocs); -inline int64_t scalapack_numroc(int64_t n, int64_t nb, int iproc, int isrcproc, int nprocs) +#define SCALAPACK_numroc BLAS_FORTRAN_NAME( numroc, NUMROC ) +extern "C" +blas_int SCALAPACK_numroc( + blas_int* n, blas_int* nb, blas_int* iproc, blas_int* isrcproc, + blas_int* nprocs ); + +inline int64_t scalapack_numroc( + int64_t n, int64_t nb, blas_int iproc, blas_int isrcproc, blas_int nprocs ) { - int n_ = int64_to_int(n); - int nb_ = int64_to_int(nb); - int nroc_ = scalapack_numroc(&n_, &nb_, &iproc, &isrcproc, &nprocs); - int64_t nroc = (int64_t)nroc_; - return nroc; + blas_int n_ = to_blas_int( n ); + blas_int nb_ = to_blas_int( nb ); + blas_int nroc_ = SCALAPACK_numroc( &n_, &nb_, &iproc, &isrcproc, &nprocs ); + return nroc_; } -#define scalapack_indxl2g BLAS_FORTRAN_NAME(indxl2g,INDXL2G) -extern "C" int scalapack_indxl2g(int* indxloc, int* nb, int* iproc, int* isrcproc, int* nprocs); +#define scalapack_indxl2g BLAS_FORTRAN_NAME( indxl2g, INDXL2G ) +extern "C" +blas_int scalapack_indxl2g( + blas_int* indxloc, blas_int* nb, blas_int* iproc, blas_int* isrcproc, + blas_int* nprocs ); } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_symm.cc b/scalapack_api/scalapack_symm.cc index b5303494..304c42ef 100644 --- a/scalapack_api/scalapack_symm.cc +++ b/scalapack_api/scalapack_symm.cc @@ -11,118 +11,19 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* 
np_col, int* my_row, int* my_col); - -// Declarations -template< typename scalar_t > -void slate_psymm(const char* side, const char* uplo, int m, int n, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t beta, scalar_t* c, int ic, int jc, int* descc); - -// ----------------------------------------------------------------------------- -// Each C interface for all Fortran interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type-generic C++ slate_pgemm routine. - -extern "C" void PDSYMM(const char* side, const char* uplo, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pdsymm(const char* side, const char* uplo, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pdsymm_(const char* side, const char* uplo, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PSSYMM(const char* side, const char* uplo, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, 
c, *ic, *jc, descc); -} - -extern "C" void pssymm(const char* side, const char* uplo, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pssymm_(const char* side, const char* uplo, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCSYMM(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcsymm(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcsymm_(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// 
----------------------------------------------------------------------------- - -extern "C" void PZSYMM(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzsymm(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzsymm_(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- -// Exposed type-specific API - -extern "C" void slate_pdsymm(const char* side, const char* uplo, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void slate_pssymm(const char* side, const char* uplo, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, 
*ic, *jc, descc); -} - -extern "C" void slate_pcsymm(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void slate_pzsymm(const char* side, const char* uplo, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psymm(side, uplo, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- +//------------------------------------------------------------------------------ // Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_psymm(const char* sidestr, const char* uplostr, int m, int n, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t beta, scalar_t* c, int ic, int jc, int* descc) +template +void slate_psymm(const char* side_str, const char* uplo_str, blas_int m, blas_int n, scalar_t alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, scalar_t beta, + scalar_t* C_data, blas_int ic, blas_int jc, blas_int const* descC ) { Side side{}; Uplo uplo{}; - from_string( std::string( 1, sidestr[0] ), &side ); - from_string( std::string( 1, uplostr[0] ), &uplo ); + from_string( std::string( 1, side_str[0] ), &side ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -137,34 +38,114 @@ void slate_psymm(const char* sidestr, const char* 
uplostr, int m, int n, scalar_ int64_t Cn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto AS = slate::SymmetricMatrix::fromScaLAPACK(uplo, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - AS = slate_scalapack_submatrix(Am, An, AS, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); - - Cblacs_gridinfo(desc_CTXT(descc), &nprow, &npcol, &myprow, &mypcol); - auto C = slate::Matrix::fromScaLAPACK(desc_M(descc), desc_N(descc), c, desc_LLD(descc), desc_MB(descc), desc_NB(descc), grid_order, nprow, npcol, MPI_COMM_WORLD); - C = slate_scalapack_submatrix(Cm, Cn, C, ic, jc, descc); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto AS = slate::SymmetricMatrix::fromScaLAPACK( + uplo, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + AS = slate_scalapack_submatrix( Am, An, AS, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); + + Cblacs_gridinfo( desc_ctxt( descC ), &nprow, &npcol, &myprow, &mypcol ); + auto C = slate::Matrix::fromScaLAPACK( + desc_m( descC ), desc_n( descC ), C_data, desc_lld( descC ), + desc_mb( descC ), desc_nb( descC ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + C = slate_scalapack_submatrix( Cm, Cn, C, ic, jc, 
descC ); if (side == blas::Side::Left) - assert(AS.mt() == C.mt()); + assert( AS.mt() == C.mt() ); else - assert(AS.mt() == C.nt()); - assert(B.mt() == C.mt()); - assert(B.nt() == C.nt()); + assert( AS.mt() == C.nt() ); + assert( B.mt() == C.mt() ); + assert( B.nt() == C.nt() ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "symm"); - slate::symm(side, alpha, AS, B, beta, C, { + slate::symm( side, alpha, AS, B, beta, C, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Each Fortran interface for all Fortran interfaces +// Each Fortran interface calls the type-generic C++ slate_pgemm routine. + +extern "C" { + +#define SCALAPACK_pssymm BLAS_FORTRAN_NAME( pssymm, PSSYMM ) +void SCALAPACK_pssymm( + const char* side, const char* uplo, + blas_int const* m, blas_int const* n, float* alpha, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + float* beta, + float* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psymm( + side, uplo, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pdsymm BLAS_FORTRAN_NAME( pdsymm, PDSYMM ) +void SCALAPACK_pdsymm( + const char* side, const char* uplo, + blas_int const* m, blas_int const* n, double* alpha, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + double* beta, + double* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psymm( + side, uplo, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pcsymm BLAS_FORTRAN_NAME( pcsymm, PCSYMM ) +void SCALAPACK_pcsymm( + const 
char* side, const char* uplo, + blas_int const* m, blas_int const* n, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psymm( + side, uplo, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pzsymm BLAS_FORTRAN_NAME( pzsymm, PZSYMM ) +void SCALAPACK_pzsymm( + const char* side, const char* uplo, + blas_int const* m, blas_int const* n, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psymm( + side, uplo, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_syr2k.cc b/scalapack_api/scalapack_syr2k.cc index 1d6ccd20..9b57a1cc 100644 --- a/scalapack_api/scalapack_syr2k.cc +++ b/scalapack_api/scalapack_syr2k.cc @@ -11,95 +11,22 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Declarations -template< typename scalar_t > -void slate_psyr2k(const char* uplostr, const char* transstr, int n, int k, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t beta, scalar_t* c, int ic, int jc, int* descc); - -// 
----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_psyr2k - -extern "C" void PDSYR2K(const char* uplo, const char* trans, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pdsyr2k(const char* uplo, const char* trans, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pdsyr2k_(const char* uplo, const char* trans, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PSSYR2K(const char* uplo, const char* trans, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pssyr2k(const char* uplo, const char* trans, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, 
*beta, c, *ic, *jc, descc); -} - -extern "C" void pssyr2k_(const char* uplo, const char* trans, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCSYR2K(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcsyr2k(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcsyr2k_(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZSYR2K(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, 
*ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzsyr2k(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzsyr2k_(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyr2k(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_psyr2k(const char* uplostr, const char* transstr, int n, int k, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb, scalar_t beta, scalar_t* c, int ic, int jc, int* descc) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_psyr2k( + const char* uplo_str, const char* trans_str, + blas_int n, blas_int k, scalar_t alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB, + scalar_t beta, + scalar_t* C_data, blas_int ic, blas_int jc, blas_int const* descC ) { Uplo uplo{}; Op trans{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, transstr[0] ), &trans ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, trans_str[0] ), &trans ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -114,39 +41,119 @@ void slate_psyr2k(const char* uplostr, const char* transstr, int n, int k, scala int64_t Cn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); - - Cblacs_gridinfo(desc_CTXT(descc), &nprow, &npcol, &myprow, &mypcol); - auto C = slate::SymmetricMatrix::fromScaLAPACK(uplo, desc_N(descc), c, desc_LLD(descc), desc_NB(descc), grid_order, nprow, npcol, MPI_COMM_WORLD); - auto CS = slate_scalapack_submatrix(Cn, Cn, C, ic, jc, descc); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA 
), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); + + Cblacs_gridinfo( desc_ctxt( descC ), &nprow, &npcol, &myprow, &mypcol ); + auto C = slate::SymmetricMatrix::fromScaLAPACK( + uplo, desc_n( descC ), C_data, desc_lld( descC ), desc_nb( descC ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + C = slate_scalapack_submatrix( Cn, Cn, C, ic, jc, descC ); if (trans == blas::Op::Trans) { - A = transpose(A); - B = transpose(B); + A = transpose( A ); + B = transpose( B ); } else if (trans == blas::Op::ConjTrans) { A = conj_transpose( A ); B = conj_transpose( B ); } - assert(A.mt() == CS.mt()); - assert(B.mt() == CS.mt()); - assert(A.nt() == B.nt()); + assert( A.mt() == C.mt() ); + assert( B.mt() == C.mt() ); + assert( A.nt() == B.nt() ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "syr2k"); - slate::syr2k(alpha, A, B, beta, CS, { + slate::syr2k( alpha, A, B, beta, C, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_pssyr2k BLAS_FORTRAN_NAME( pssyr2k, PSSYR2K ) +void SCALAPACK_pssyr2k( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, float* alpha, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + float* beta, + float* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psyr2k( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pdsyr2k BLAS_FORTRAN_NAME( pdsyr2k, PDSYR2K ) +void SCALAPACK_pdsyr2k( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, double* alpha, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + double* beta, + double* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psyr2k( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pcsyr2k BLAS_FORTRAN_NAME( pcsyr2k, PCSYR2K ) +void SCALAPACK_pcsyr2k( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psyr2k( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pzsyr2k BLAS_FORTRAN_NAME( pzsyr2k, PZSYR2K ) +void SCALAPACK_pzsyr2k( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, std::complex* alpha, 
+ std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psyr2k( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB, *beta, + C_data, *ic, *jc, descC ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_syrk.cc b/scalapack_api/scalapack_syrk.cc index 58a20d89..cb08897c 100644 --- a/scalapack_api/scalapack_syrk.cc +++ b/scalapack_api/scalapack_syrk.cc @@ -11,95 +11,21 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Declarations -template< typename scalar_t > -void slate_psyrk(const char* uplostr, const char* transstr, int n, int k, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t beta, scalar_t* c, int ic, int jc, int* descc); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_psyrk - -extern "C" void PDSYRK(const char* uplo, const char* trans, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pdsyrk(const char* uplo, const char* trans, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void 
pdsyrk_(const char* uplo, const char* trans, int* n, int* k, double* alpha, double* a, int* ia, int* ja, int* desca, double* beta, double* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PSSYRK(const char* uplo, const char* trans, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pssyrk(const char* uplo, const char* trans, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pssyrk_(const char* uplo, const char* trans, int* n, int* k, float* alpha, float* a, int* ia, int* ja, int* desca, float* beta, float* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCSYRK(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcsyrk(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pcsyrk_(const char* uplo, const char* trans, int* n, int* k, 
std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZSYRK(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzsyrk(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -extern "C" void pzsyrk_(const char* uplo, const char* trans, int* n, int* k, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* beta, std::complex* c, int* ic, int* jc, int* descc) -{ - slate_psyrk(uplo, trans, *n, *k, *alpha, a, *ia, *ja, desca, *beta, c, *ic, *jc, descc); -} - -// ----------------------------------------------------------------------------- - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_psyrk(const char* uplostr, const char* transstr, int n, int k, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t beta, scalar_t* c, int ic, int jc, int* descc) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_psyrk( + const char* uplo_str, const char* trans_str, + blas_int n, blas_int k, scalar_t alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t beta, + scalar_t* C_data, blas_int ic, blas_int jc, blas_int const* descC ) { Uplo uplo{}; Op transA{}; - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, transstr[0] ), &transA ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, trans_str[0] ), &transA ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -113,29 +39,98 @@ void slate_psyrk(const char* uplostr, const char* transstr, int n, int k, scalar int64_t Cn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto A = slate::Matrix::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descc), &nprow, &npcol, &myprow, &mypcol); - auto C = slate::SymmetricMatrix::fromScaLAPACK(uplo, desc_N(descc), c, desc_LLD(descc), desc_NB(descc), grid_order, nprow, npcol, MPI_COMM_WORLD); - C = slate_scalapack_submatrix(Cm, Cn, C, ic, jc, descc); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto A = slate::Matrix::fromScaLAPACK( + desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ), + desc_mb( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descC ), &nprow, &npcol, &myprow, &mypcol ); + auto C = slate::SymmetricMatrix::fromScaLAPACK( + uplo, desc_n( descC ), C_data, desc_lld( descC ), desc_nb( descC ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); 
+ C = slate_scalapack_submatrix( Cm, Cn, C, ic, jc, descC ); if (transA == blas::Op::Trans) - A = transpose(A); + A = transpose( A ); else if (transA == blas::Op::ConjTrans) A = conj_transpose( A ); - assert(A.mt() == C.mt()); + assert( A.mt() == C.mt() ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "syrk"); - slate::syrk(alpha, A, beta, C, { + slate::syrk( alpha, A, beta, C, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_pssyrk BLAS_FORTRAN_NAME( pssyrk, PSSYRK ) +void SCALAPACK_pssyrk( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, float* alpha, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* beta, + float* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psyrk( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pdsyrk BLAS_FORTRAN_NAME( pdsyrk, PDSYRK ) +void SCALAPACK_pdsyrk( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, double* alpha, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* beta, + double* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psyrk( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pcsyrk BLAS_FORTRAN_NAME( pcsyrk, PCSYRK ) +void SCALAPACK_pcsyrk( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + 
slate_psyrk( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, *beta, + C_data, *ic, *jc, descC ); +} + +#define SCALAPACK_pzsyrk BLAS_FORTRAN_NAME( pzsyrk, PZSYRK ) +void SCALAPACK_pzsyrk( + const char* uplo, const char* trans, + blas_int const* n, blas_int const* k, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* beta, + std::complex* C_data, blas_int const* ic, blas_int const* jc, blas_int const* descC ) +{ + slate_psyrk( + uplo, trans, *n, *k, *alpha, + A_data, *ia, *ja, descA, *beta, + C_data, *ic, *jc, descC ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_trcon.cc b/scalapack_api/scalapack_trcon.cc index 79876761..156fd6ca 100644 --- a/scalapack_api/scalapack_trcon.cc +++ b/scalapack_api/scalapack_trcon.cc @@ -8,95 +8,26 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_ptrcon(const char* normstr, const char* uplostr, const char* diagstr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_pher2k - -extern "C" void PSTRCON(const char* normstr, const char* uplostr, const char* diagstr, int* n, float* a, int* ia, int* ja, int* desca, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" 
void pstrcon(const char* normstr, const char* uplostr, const char* diagstr, int* n, float* a, int* ia, int* ja, int* desca, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pstrcon_(const char* normstr, const char* uplostr, const char* diagstr, int* n, float* a, int* ia, int* ja, int* desca, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PDTRCON(const char* normstr, const char* uplostr, const char* diagstr, int* n, double* a, int* ia, int* ja, int* desca, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pdtrcon(const char* normstr, const char* uplostr, const char* diagstr, int* n, double* a, int* ia, int* ja, int* desca, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, iwork, *liwork, info); -} - -extern "C" void pdtrcon_(const char* normstr, const char* uplostr, const char* diagstr, int* n, double* a, int* ia, int* ja, int* desca, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, iwork, *liwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCTRCON(const char* normstr, const char* uplostr, const char* diagstr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* rcond, std::complex* work, int* 
lwork, float* rwork, int* lrwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pctrcon(const char* normstr, const char* uplostr, const char* diagstr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* rcond, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pctrcon_(const char* normstr, const char* uplostr, const char* diagstr, int* n, std::complex* a, int* ia, int* ja, int* desca, float* rcond, std::complex* work, int* lwork, float* rwork, int* lrwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, rwork, *lrwork, info); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZTRCON(const char* normstr, const char* uplostr, const char* diagstr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pztrcon(const char* normstr, const char* uplostr, const char* diagstr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, rwork, *lrwork, info); -} - -extern "C" void pztrcon_(const char* normstr, const char* uplostr, const char* diagstr, int* n, std::complex* a, int* ia, int* ja, int* desca, double* rcond, std::complex* work, int* lwork, double* rwork, int* lrwork, int* info) -{ - slate_ptrcon(normstr, uplostr, diagstr, *n, a, *ia, *ja, desca, rcond, work, *lwork, rwork, *lrwork, info); -} - -// 
----------------------------------------------------------------------------- -template< typename scalar_t > -void slate_ptrcon(const char* normstr, const char* uplostr, const char* diagstr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +/// If scalar_t is real, irwork is integer. +/// If scalar_t is complex, irwork is real. +template +void slate_ptrcon( + const char* norm_str, const char* uplo_str, const char* diag_str, blas_int n, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + blas::real_type* rcond, + scalar_t* work, blas_int lwork, + void* irwork, blas_int lirwork, + blas_int* info ) { Norm norm{}; Uplo uplo{}; Diag diag{}; - from_string( std::string( 1, normstr[0] ), &norm ); - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, diagstr[0] ), &diag ); + from_string( std::string( 1, norm_str[0] ), &norm ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, diag_str[0] ), &diag ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -108,15 +39,15 @@ void slate_ptrcon(const char* normstr, const char* uplostr, const char* diagstr, // todo: extract the real info from getrf *info = 0; - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "trcon"); if (lwork == -1 || lirwork == -1) { *work = 0; if constexpr (std::is_same_v>) { - *(int*)irwork = 0; + *(blas_int*)irwork = 0; } else { *(blas::real_type*)irwork = 0; @@ -129,17 +60,19 @@ void 
slate_ptrcon(const char* normstr, const char* uplostr, const char* diagstr, int64_t An = n; // create SLATE matrices from the ScaLAPACK layouts - auto AT = slate::TriangularMatrix::fromScaLAPACK(uplo, diag, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - AT = slate_scalapack_submatrix(Am, An, AT, ia, ja, desca); + auto AT = slate::TriangularMatrix::fromScaLAPACK( + uplo, diag, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + AT = slate_scalapack_submatrix( Am, An, AT, ia, ja, descA ); - blas::real_type anorm = slate::norm( norm, AT, { + blas::real_type Anorm = slate::norm( norm, AT, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target}, {slate::Option::MaxPanelThreads, panel_threads}, {slate::Option::InnerBlocking, ib} }); - *rcond = slate::trcondest( norm, AT, anorm, { + *rcond = slate::trcondest( norm, AT, Anorm, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target}, {slate::Option::MaxPanelThreads, panel_threads}, @@ -147,5 +80,73 @@ void slate_ptrcon(const char* normstr, const char* uplostr, const char* diagstr, }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. 
+ +extern "C" { + +#define SCALAPACK_pstrcon BLAS_FORTRAN_NAME( pstrcon, PSTRCON ) +void SCALAPACK_pstrcon( + const char* norm, const char* uplo, const char* diag, blas_int const* n, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* rcond, + float* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_ptrcon( + norm, uplo, diag, *n, + A_data, *ia, *ja, descA, + rcond, work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pdtrcon BLAS_FORTRAN_NAME( pdtrcon, PDTRCON ) +void SCALAPACK_pdtrcon( + const char* norm, const char* uplo, const char* diag, blas_int const* n, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* rcond, + double* work, blas_int const* lwork, + blas_int* iwork, blas_int const* liwork, + blas_int* info ) +{ + slate_ptrcon( + norm, uplo, diag, *n, + A_data, *ia, *ja, descA, + rcond, work, *lwork, iwork, *liwork, info ); +} + +#define SCALAPACK_pctrcon BLAS_FORTRAN_NAME( pctrcon, PCTRCON ) +void SCALAPACK_pctrcon( + const char* norm, const char* uplo, const char* diag, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* rcond, + std::complex* work, blas_int const* lwork, + float* rwork, blas_int const* lrwork, + blas_int* info ) +{ + slate_ptrcon( + norm, uplo, diag, *n, + A_data, *ia, *ja, descA, + rcond, work, *lwork, rwork, *lrwork, info ); +} + +#define SCALAPACK_pztrcon BLAS_FORTRAN_NAME( pztrcon, PZTRCON ) +void SCALAPACK_pztrcon( + const char* norm, const char* uplo, const char* diag, blas_int const* n, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* rcond, + std::complex* work, blas_int const* lwork, + double* rwork, blas_int const* lrwork, + blas_int* info ) +{ + slate_ptrcon( + norm, uplo, diag, *n, + A_data, *ia, *ja, descA, + rcond, work, *lwork, rwork, *lrwork, info ); +} + +} // extern "C" + } 
// namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_trmm.cc b/scalapack_api/scalapack_trmm.cc index d840bdee..d93a3ea3 100644 --- a/scalapack_api/scalapack_trmm.cc +++ b/scalapack_api/scalapack_trmm.cc @@ -11,99 +11,25 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Declarations -template< typename scalar_t > -void slate_ptrmm(const char* side, const char* uplo, const char* transa, const char* diag, int m, int n, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_ptrmm - -extern "C" void PDTRMM(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pdtrmm(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pdtrmm_(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -// 
----------------------------------------------------------------------------- - -extern "C" void PSTRMM(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pstrmm(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pstrmm_(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PCTRMM(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pctrmm(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pctrmm_(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* 
descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZTRMM(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pztrmm(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pztrmm_(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrmm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -// ----------------------------------------------------------------------------- - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_ptrmm(const char* sidestr, const char* uplostr, const char* transastr, const char* diagstr, int m, int n, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. 
+template +void slate_ptrmm( + const char* side_str, const char* uplo_str, const char* transA_str, + const char* diag_str, + blas_int m, blas_int n, scalar_t alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB) { Side side{}; Uplo uplo{}; Op transA{}; Diag diag{}; - from_string( std::string( 1, sidestr[0] ), &side ); - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, transastr[0] ), &transA ); - from_string( std::string( 1, diagstr[0] ), &diag ); + from_string( std::string( 1, side_str[0] ), &side ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, transA_str[0] ), &transA ); + from_string( std::string( 1, diag_str[0] ), &diag ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -117,28 +43,93 @@ void slate_ptrmm(const char* sidestr, const char* uplostr, const char* transastr int64_t Bn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto AT = slate::TriangularMatrix::fromScaLAPACK(uplo, diag, desc_N(desca), a, desc_LLD(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - AT = slate_scalapack_submatrix(Am, An, AT, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto AT = slate::TriangularMatrix::fromScaLAPACK( + uplo, diag, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + AT = 
slate_scalapack_submatrix( Am, An, AT, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); if (transA == Op::Trans) - AT = transpose(AT); + AT = transpose( AT ); else if (transA == Op::ConjTrans) AT = conj_transpose( AT ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "trmm"); - slate::trmm(side, alpha, AT, B, { + slate::trmm( side, alpha, AT, B, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper. + +extern "C" { + +#define SCALAPACK_pdtrmm BLAS_FORTRAN_NAME( pdtrmm, PDTRMM ) +void SCALAPACK_pdtrmm( + const char* side, const char* uplo, const char* transA, const char* diag, + blas_int const* m, blas_int const* n, double* alpha, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB ) +{ + slate_ptrmm( + side, uplo, transA, diag, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB ); +} + +#define SCALAPACK_pstrmm BLAS_FORTRAN_NAME( pstrmm, PSTRMM ) +void SCALAPACK_pstrmm( + const char* side, const char* uplo, const char* transA, const char* diag, + blas_int const* m, blas_int const* n, float* alpha, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB ) +{ + slate_ptrmm( + side, uplo, transA, diag, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB ); +} + +#define SCALAPACK_pctrmm BLAS_FORTRAN_NAME( pctrmm, PCTRMM ) +void 
SCALAPACK_pctrmm( + const char* side, const char* uplo, const char* transA, const char* diag, + blas_int const* m, blas_int const* n, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB ) +{ + slate_ptrmm( + side, uplo, transA, diag, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB ); +} + +#define SCALAPACK_pztrmm BLAS_FORTRAN_NAME( pztrmm, PZTRMM ) +void SCALAPACK_pztrmm( + const char* side, const char* uplo, const char* transA, const char* diag, + blas_int const* m, blas_int const* n, std::complex* alpha, + std::complex* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB ) +{ + slate_ptrmm( + side, uplo, transA, diag, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate diff --git a/scalapack_api/scalapack_trsm.cc b/scalapack_api/scalapack_trsm.cc index 5f9b3409..b40faac9 100644 --- a/scalapack_api/scalapack_trsm.cc +++ b/scalapack_api/scalapack_trsm.cc @@ -11,99 +11,25 @@ namespace slate { namespace scalapack_api { -// ----------------------------------------------------------------------------- - -// Required CBLACS calls -extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col); - -// Declarations -template -void slate_ptrsm(const char* sidestr, const char* uplostr, const char* transastr, const char* diagstr, int m, int n, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb); - -// ----------------------------------------------------------------------------- -// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE) -// Each C interface calls the type generic slate_ptrsm - -extern "C" void PDTRSM(const char* 
side, const char* uplo, const char* transa, const char* diag, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pdtrsm(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pdtrsm_(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, double* alpha, double* a, int* ia, int* ja, int* desca, double* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PSTRSM(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pstrsm(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pstrsm_(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, float* alpha, float* a, int* ia, int* ja, int* desca, float* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -// 
----------------------------------------------------------------------------- - -extern "C" void PCTRSM(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pctrsm(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pctrsm_(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -// ----------------------------------------------------------------------------- - -extern "C" void PZTRSM(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pztrsm(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -extern "C" void pztrsm_(const char* side, const char* uplo, const char* transa, const char* diag, int* m, int* n, std::complex* alpha, std::complex* a, int* ia, 
int* ja, int* desca, std::complex* b, int* ib, int* jb, int* descb) -{ - slate_ptrsm(side, uplo, transa, diag, *m, *n, *alpha, a, *ia, *ja, desca, b, *ib, *jb, descb); -} - -// ----------------------------------------------------------------------------- - -// Type generic function calls the SLATE routine -template< typename scalar_t > -void slate_ptrsm(const char* sidestr, const char* uplostr, const char* transastr, const char* diagstr, int m, int n, scalar_t alpha, scalar_t* a, int ia, int ja, int* desca, scalar_t* b, int ib, int jb, int* descb) +//------------------------------------------------------------------------------ +/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors +/// and calls SLATE. +template +void slate_ptrsm( + const char* side_str, const char* uplo_str, const char* transA_str, + const char* diag_str, + blas_int m, blas_int n, scalar_t alpha, + scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA, + scalar_t* B_data, blas_int ib, blas_int jb, blas_int const* descB) { Side side{}; Uplo uplo{}; Op transA{}; Diag diag{}; - from_string( std::string( 1, sidestr[0] ), &side ); - from_string( std::string( 1, uplostr[0] ), &uplo ); - from_string( std::string( 1, transastr[0] ), &transA ); - from_string( std::string( 1, diagstr[0] ), &diag ); + from_string( std::string( 1, side_str[0] ), &side ); + from_string( std::string( 1, uplo_str[0] ), &uplo ); + from_string( std::string( 1, transA_str[0] ), &transA ); + from_string( std::string( 1, diag_str[0] ), &diag ); slate::Target target = TargetConfig::value(); int verbose = VerboseConfig::value(); @@ -117,28 +43,93 @@ void slate_ptrsm(const char* sidestr, const char* uplostr, const char* transastr int64_t Bn = n; // create SLATE matrices from the ScaLAPACK layouts - int nprow, npcol, myprow, mypcol; - Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol); - auto AT = slate::TriangularMatrix::fromScaLAPACK(uplo, diag, desc_N(desca), a, desc_LLD(desca), 
desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD); - AT = slate_scalapack_submatrix(Am, An, AT, ia, ja, desca); - - Cblacs_gridinfo(desc_CTXT(descb), &nprow, &npcol, &myprow, &mypcol); - auto B = slate::Matrix<scalar_t>::fromScaLAPACK(desc_M(descb), desc_N(descb), b, desc_LLD(descb), desc_MB(descb), desc_NB(descb), grid_order, nprow, npcol, MPI_COMM_WORLD); - B = slate_scalapack_submatrix(Bm, Bn, B, ib, jb, descb); + blas_int nprow, npcol, myprow, mypcol; + Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol ); + auto AT = slate::TriangularMatrix<scalar_t>::fromScaLAPACK( + uplo, diag, desc_n( descA ), A_data, desc_lld( descA ), desc_nb( descA ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + AT = slate_scalapack_submatrix( Am, An, AT, ia, ja, descA ); + + Cblacs_gridinfo( desc_ctxt( descB ), &nprow, &npcol, &myprow, &mypcol ); + auto B = slate::Matrix<scalar_t>::fromScaLAPACK( + desc_m( descB ), desc_n( descB ), B_data, desc_lld( descB ), + desc_mb( descB ), desc_nb( descB ), + grid_order, nprow, npcol, MPI_COMM_WORLD ); + B = slate_scalapack_submatrix( Bm, Bn, B, ib, jb, descB ); if (transA == Op::Trans) - AT = transpose(AT); + AT = transpose( AT ); else if (transA == Op::ConjTrans) AT = conj_transpose( AT ); if (verbose && myprow == 0 && mypcol == 0) logprintf("%s\n", "trsm"); - slate::trsm(side, alpha, AT, B, { + slate::trsm( side, alpha, AT, B, { {slate::Option::Lookahead, lookahead}, {slate::Option::Target, target} }); } +//------------------------------------------------------------------------------ +// Fortran interfaces +// Each Fortran interface calls the type generic slate wrapper.
+ +extern "C" { + +#define SCALAPACK_pdtrsm BLAS_FORTRAN_NAME( pdtrsm, PDTRSM ) +void SCALAPACK_pdtrsm( + const char* side, const char* uplo, const char* transA, const char* diag, + blas_int const* m, blas_int const* n, double* alpha, + double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + double* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB ) +{ + slate_ptrsm( + side, uplo, transA, diag, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB ); +} + +#define SCALAPACK_pstrsm BLAS_FORTRAN_NAME( pstrsm, PSTRSM ) +void SCALAPACK_pstrsm( + const char* side, const char* uplo, const char* transA, const char* diag, + blas_int const* m, blas_int const* n, float* alpha, + float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + float* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB ) +{ + slate_ptrsm( + side, uplo, transA, diag, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB ); +} + +#define SCALAPACK_pctrsm BLAS_FORTRAN_NAME( pctrsm, PCTRSM ) +void SCALAPACK_pctrsm( + const char* side, const char* uplo, const char* transA, const char* diag, + blas_int const* m, blas_int const* n, std::complex<float>* alpha, + std::complex<float>* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex<float>* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB ) +{ + slate_ptrsm( + side, uplo, transA, diag, *m, *n, *alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB ); +} + +#define SCALAPACK_pztrsm BLAS_FORTRAN_NAME( pztrsm, PZTRSM ) +void SCALAPACK_pztrsm( + const char* side, const char* uplo, const char* transA, const char* diag, + blas_int const* m, blas_int const* n, std::complex<double>* alpha, + std::complex<double>* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA, + std::complex<double>* B_data, blas_int const* ib, blas_int const* jb, blas_int const* descB ) +{ + slate_ptrsm( + side, uplo, transA, diag, *m, *n,
*alpha, + A_data, *ia, *ja, descA, + B_data, *ib, *jb, descB ); +} + +} // extern "C" + } // namespace scalapack_api } // namespace slate