Skip to content

Commit

Permalink
scalapack API: style, move template before Fortran wrappers, wrap to …
Browse files Browse the repository at this point in the history
…80 chars.
  • Loading branch information
mgates3 committed Jan 14, 2025
1 parent 88f2beb commit 8c433e3
Show file tree
Hide file tree
Showing 30 changed files with 2,817 additions and 2,494 deletions.
176 changes: 91 additions & 85 deletions scalapack_api/scalapack_gecon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,87 +8,22 @@
namespace slate {
namespace scalapack_api {

// -----------------------------------------------------------------------------
// Type generic function calls the SLATE routine
template< typename scalar_t >
void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type<scalar_t> anorm, blas::real_type<scalar_t>* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info);

// -----------------------------------------------------------------------------
// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE)
// Each C interface calls the type generic slate_pher2k

extern "C" void PSGECON(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

extern "C" void psgecon(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

extern "C" void psgecon_(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

// -----------------------------------------------------------------------------

extern "C" void PDGECON(const char* normstr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

extern "C" void pdgecon(const char* normstr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

extern "C" void pdgecon_(const char* normstr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

// -----------------------------------------------------------------------------

extern "C" void PCGECON(const char* normstr, int* n, std::complex<float>* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex<float>* work, int* lwork, float* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

extern "C" void pcgecon(const char* normstr, int* n, std::complex<float>* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex<float>* work, int* lwork, float* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

extern "C" void pcgecon_(const char* normstr, int* n, std::complex<float>* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex<float>* work, int* lwork, float* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

// -----------------------------------------------------------------------------

extern "C" void PZGECON(const char* normstr, int* n, std::complex<double>* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex<double>* work, int* lwork, double* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

extern "C" void pzgecon(const char* normstr, int* n, std::complex<double>* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex<double>* work, int* lwork, double* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

extern "C" void pzgecon_(const char* normstr, int* n, std::complex<double>* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex<double>* work, int* lwork, double* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

// -----------------------------------------------------------------------------
template< typename scalar_t >
void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type<scalar_t> anorm, blas::real_type<scalar_t>* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info)
//------------------------------------------------------------------------------
/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors
/// and calls SLATE.
/// If scalar_t is real, irwork is integer.
/// If scalar_t is complex, irwork is real.
template <typename scalar_t>
void slate_pgecon(
const char* norm_str, blas_int n,
scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA,
blas::real_type<scalar_t> Anorm, blas::real_type<scalar_t>* rcond,
scalar_t* work, blas_int lwork,
void* irwork, blas_int lirwork,
blas_int* info )
{
Norm norm{};
from_string( std::string( 1, normstr[0] ), &norm );
from_string( std::string( 1, norm_str[0] ), &norm );

slate::Target target = TargetConfig::value();
int verbose = VerboseConfig::value();
Expand All @@ -100,15 +35,15 @@ void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int*
// todo: extract the real info from getrf
*info = 0;

int nprow, npcol, myprow, mypcol;
Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol);
blas_int nprow, npcol, myprow, mypcol;
Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol );
if (verbose && myprow == 0 && mypcol == 0)
logprintf("%s\n", "gecon");

if (lwork == -1 || lirwork == -1) {
*work = 0;
if constexpr (std::is_same_v<scalar_t, blas::real_type<scalar_t>>) {
*(int*)irwork = 0;
*(blas_int*)irwork = 0;
}
else {
*(blas::real_type<scalar_t>*)irwork = 0;
Expand All @@ -121,16 +56,87 @@ void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int*
int64_t An = n;

// create SLATE matrices from the ScaLAPACK layouts
auto A = slate::Matrix<scalar_t>::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD);
A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca);
auto A = slate::Matrix<scalar_t>::fromScaLAPACK(
desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ),
desc_mb( descA ), desc_nb( descA ),
grid_order, nprow, npcol, MPI_COMM_WORLD );
A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA );

*rcond = slate::gecondest(norm, A, anorm, {
*rcond = slate::gecondest( norm, A, Anorm, {
{slate::Option::Lookahead, lookahead},
{slate::Option::Target, target},
{slate::Option::MaxPanelThreads, panel_threads},
{slate::Option::InnerBlocking, ib}
});
}

//------------------------------------------------------------------------------
// Fortran interfaces
// Each Fortran interface calls the type generic slate wrapper.

extern "C" {

#define SCALAPACK_psgecon BLAS_FORTRAN_NAME( psgecon, PSGECON )
void SCALAPACK_psgecon(
const char* norm, blas_int const* n,
float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA,
float* Anorm, float* rcond,
float* work, blas_int const* lwork,
blas_int* iwork, blas_int const* liwork,
blas_int* info )
{
slate_pgecon(
norm, *n,
A_data, *ia, *ja, descA, *Anorm, rcond,
work, *lwork, iwork, *liwork, info );
}

#define SCALAPACK_pdgecon BLAS_FORTRAN_NAME( pdgecon, PDGECON )
void SCALAPACK_pdgecon(
const char* norm, blas_int const* n,
double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA,
double* Anorm, double* rcond,
double* work, blas_int const* lwork,
blas_int* iwork, blas_int const* liwork,
blas_int* info )
{
slate_pgecon(
norm, *n,
A_data, *ia, *ja, descA, *Anorm, rcond,
work, *lwork, iwork, *liwork, info );
}

#define SCALAPACK_pcgecon BLAS_FORTRAN_NAME( pcgecon, PCGECON )
void SCALAPACK_pcgecon(
const char* norm, blas_int const* n,
std::complex<float>* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA,
float* Anorm, float* rcond,
std::complex<float>* work, blas_int const* lwork,
float* rwork, blas_int const* lrwork,
blas_int* info )
{
slate_pgecon(
norm, *n,
A_data, *ia, *ja, descA, *Anorm, rcond,
work, *lwork, rwork, *lrwork, info );
}

#define SCALAPACK_pzgecon BLAS_FORTRAN_NAME( pzgecon, PZGECON )
void SCALAPACK_pzgecon(
const char* norm, blas_int const* n,
std::complex<double>* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA,
double* Anorm, double* rcond,
std::complex<double>* work, blas_int const* lwork,
double* rwork, blas_int const* lrwork,
blas_int* info )
{
slate_pgecon(
norm, *n,
A_data, *ia, *ja, descA, *Anorm, rcond,
work, *lwork, rwork, *lrwork, info );
}

} // extern "C"

} // namespace scalapack_api
} // namespace slate
Loading

0 comments on commit 8c433e3

Please sign in to comment.