Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ScaLAPACK API style #210

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 91 additions & 89 deletions scalapack_api/scalapack_gecon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,91 +8,22 @@
namespace slate {
namespace scalapack_api {

// -----------------------------------------------------------------------------

// Required CBLACS calls
extern "C" void Cblacs_gridinfo(int context, int* np_row, int* np_col, int* my_row, int* my_col);

// Type generic function calls the SLATE routine
template< typename scalar_t >
void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type<scalar_t> anorm, blas::real_type<scalar_t>* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info);

// -----------------------------------------------------------------------------
// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE)
// Each C interface calls the type generic slate_pher2k

extern "C" void PSGECON(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

extern "C" void psgecon(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

extern "C" void psgecon_(const char* normstr, int* n, float* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, float* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

// -----------------------------------------------------------------------------

extern "C" void PDGECON(const char* normstr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

extern "C" void pdgecon(const char* normstr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

extern "C" void pdgecon_(const char* normstr, int* n, double* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, double* work, int* lwork, int* iwork, int* liwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, iwork, *liwork, info);
}

// -----------------------------------------------------------------------------

extern "C" void PCGECON(const char* normstr, int* n, std::complex<float>* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex<float>* work, int* lwork, float* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

extern "C" void pcgecon(const char* normstr, int* n, std::complex<float>* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex<float>* work, int* lwork, float* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

extern "C" void pcgecon_(const char* normstr, int* n, std::complex<float>* a, int* ia, int* ja, int* desca, float* anorm, float* rcond, std::complex<float>* work, int* lwork, float* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

// -----------------------------------------------------------------------------

extern "C" void PZGECON(const char* normstr, int* n, std::complex<double>* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex<double>* work, int* lwork, double* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

extern "C" void pzgecon(const char* normstr, int* n, std::complex<double>* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex<double>* work, int* lwork, double* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

extern "C" void pzgecon_(const char* normstr, int* n, std::complex<double>* a, int* ia, int* ja, int* desca, double* anorm, double* rcond, std::complex<double>* work, int* lwork, double* rwork, int* lrwork, int* info)
{
slate_pgecon(normstr, *n, a, *ia, *ja, desca, *anorm, rcond, work, *lwork, rwork, *lrwork, info);
}

// -----------------------------------------------------------------------------
template< typename scalar_t >
void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int* desca, blas::real_type<scalar_t> anorm, blas::real_type<scalar_t>* rcond, scalar_t* work, int lwork, void* irwork, int lirwork, int* info)
//------------------------------------------------------------------------------
/// SLATE ScaLAPACK wrapper sets up SLATE matrices from ScaLAPACK descriptors
/// and calls SLATE.
/// If scalar_t is real, irwork is integer.
/// If scalar_t is complex, irwork is real.
template <typename scalar_t>
void slate_pgecon(
const char* norm_str, blas_int n,
scalar_t* A_data, blas_int ia, blas_int ja, blas_int const* descA,
blas::real_type<scalar_t> Anorm, blas::real_type<scalar_t>* rcond,
scalar_t* work, blas_int lwork,
void* irwork, blas_int lirwork,
blas_int* info )
{
Norm norm{};
from_string( std::string( 1, normstr[0] ), &norm );
from_string( std::string( 1, norm_str[0] ), &norm );

slate::Target target = TargetConfig::value();
int verbose = VerboseConfig::value();
Expand All @@ -104,15 +35,15 @@ void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int*
// todo: extract the real info from getrf
*info = 0;

int nprow, npcol, myprow, mypcol;
Cblacs_gridinfo(desc_CTXT(desca), &nprow, &npcol, &myprow, &mypcol);
blas_int nprow, npcol, myprow, mypcol;
Cblacs_gridinfo( desc_ctxt( descA ), &nprow, &npcol, &myprow, &mypcol );
if (verbose && myprow == 0 && mypcol == 0)
logprintf("%s\n", "gecon");

if (lwork == -1 || lirwork == -1) {
*work = 0;
if constexpr (std::is_same_v<scalar_t, blas::real_type<scalar_t>>) {
*(int*)irwork = 0;
*(blas_int*)irwork = 0;
}
else {
*(blas::real_type<scalar_t>*)irwork = 0;
Expand All @@ -125,16 +56,87 @@ void slate_pgecon(const char* normstr, int n, scalar_t* a, int ia, int ja, int*
int64_t An = n;

// create SLATE matrices from the ScaLAPACK layouts
auto A = slate::Matrix<scalar_t>::fromScaLAPACK(desc_M(desca), desc_N(desca), a, desc_LLD(desca), desc_MB(desca), desc_NB(desca), grid_order, nprow, npcol, MPI_COMM_WORLD);
A = slate_scalapack_submatrix(Am, An, A, ia, ja, desca);
auto A = slate::Matrix<scalar_t>::fromScaLAPACK(
desc_m( descA ), desc_n( descA ), A_data, desc_lld( descA ),
desc_mb( descA ), desc_nb( descA ),
grid_order, nprow, npcol, MPI_COMM_WORLD );
A = slate_scalapack_submatrix( Am, An, A, ia, ja, descA );

*rcond = slate::gecondest(norm, A, anorm, {
*rcond = slate::gecondest( norm, A, Anorm, {
{slate::Option::Lookahead, lookahead},
{slate::Option::Target, target},
{slate::Option::MaxPanelThreads, panel_threads},
{slate::Option::InnerBlocking, ib}
});
}

//------------------------------------------------------------------------------
// Fortran interfaces
// Each Fortran interface calls the type generic slate wrapper.

extern "C" {

#define SCALAPACK_psgecon BLAS_FORTRAN_NAME( psgecon, PSGECON )
void SCALAPACK_psgecon(
const char* norm, blas_int const* n,
float* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA,
float* Anorm, float* rcond,
float* work, blas_int const* lwork,
blas_int* iwork, blas_int const* liwork,
blas_int* info )
{
slate_pgecon(
norm, *n,
A_data, *ia, *ja, descA, *Anorm, rcond,
work, *lwork, iwork, *liwork, info );
}

#define SCALAPACK_pdgecon BLAS_FORTRAN_NAME( pdgecon, PDGECON )
void SCALAPACK_pdgecon(
const char* norm, blas_int const* n,
double* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA,
double* Anorm, double* rcond,
double* work, blas_int const* lwork,
blas_int* iwork, blas_int const* liwork,
blas_int* info )
{
slate_pgecon(
norm, *n,
A_data, *ia, *ja, descA, *Anorm, rcond,
work, *lwork, iwork, *liwork, info );
}

#define SCALAPACK_pcgecon BLAS_FORTRAN_NAME( pcgecon, PCGECON )
void SCALAPACK_pcgecon(
const char* norm, blas_int const* n,
std::complex<float>* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA,
float* Anorm, float* rcond,
std::complex<float>* work, blas_int const* lwork,
float* rwork, blas_int const* lrwork,
blas_int* info )
{
slate_pgecon(
norm, *n,
A_data, *ia, *ja, descA, *Anorm, rcond,
work, *lwork, rwork, *lrwork, info );
}

#define SCALAPACK_pzgecon BLAS_FORTRAN_NAME( pzgecon, PZGECON )
void SCALAPACK_pzgecon(
const char* norm, blas_int const* n,
std::complex<double>* A_data, blas_int const* ia, blas_int const* ja, blas_int const* descA,
double* Anorm, double* rcond,
std::complex<double>* work, blas_int const* lwork,
double* rwork, blas_int const* lrwork,
blas_int* info )
{
slate_pgecon(
norm, *n,
A_data, *ia, *ja, descA, *Anorm, rcond,
work, *lwork, rwork, *lrwork, info );
}

} // extern "C"

} // namespace scalapack_api
} // namespace slate
Loading