Skip to content

Commit

Permalink
Merge pull request #149 from neil-lindquist/la-scala-apis
Browse files Browse the repository at this point in the history
Add SVD, Hermitian Eig, and condest to LAPACK and ScaLAPACK APIS
  • Loading branch information
neil-lindquist authored Dec 14, 2023
2 parents 8693751 + a1efb62 commit f9ac8c7
Show file tree
Hide file tree
Showing 15 changed files with 1,801 additions and 16 deletions.
12 changes: 12 additions & 0 deletions GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -1104,26 +1104,32 @@ scalapack_api_so = lib/libslate_scalapack_api.so
scalapack_api = lib/libslate_scalapack_api.$(lib_ext)

# Source files appended to scalapack_api_src; keep the entries in alphabetical order.
scalapack_api_src += \
scalapack_api/scalapack_gecon.cc \
scalapack_api/scalapack_gels.cc \
scalapack_api/scalapack_gemm.cc \
scalapack_api/scalapack_gesv.cc \
scalapack_api/scalapack_gesv_mixed.cc \
scalapack_api/scalapack_gesvd.cc \
scalapack_api/scalapack_getrf.cc \
scalapack_api/scalapack_getrs.cc \
scalapack_api/scalapack_heev.cc \
scalapack_api/scalapack_heevd.cc \
scalapack_api/scalapack_hemm.cc \
scalapack_api/scalapack_her2k.cc \
scalapack_api/scalapack_herk.cc \
scalapack_api/scalapack_lange.cc \
scalapack_api/scalapack_lanhe.cc \
scalapack_api/scalapack_lansy.cc \
scalapack_api/scalapack_lantr.cc \
scalapack_api/scalapack_pocon.cc \
scalapack_api/scalapack_posv.cc \
scalapack_api/scalapack_potrf.cc \
scalapack_api/scalapack_potri.cc \
scalapack_api/scalapack_potrs.cc \
scalapack_api/scalapack_symm.cc \
scalapack_api/scalapack_syr2k.cc \
scalapack_api/scalapack_syrk.cc \
scalapack_api/scalapack_trcon.cc \
scalapack_api/scalapack_trmm.cc \
scalapack_api/scalapack_trsm.cc \
# End. Add alphabetically.
Expand Down Expand Up @@ -1162,26 +1168,32 @@ lapack_api_so = lib/libslate_lapack_api.so
lapack_api = lib/libslate_lapack_api.$(lib_ext)

# Source files appended to lapack_api_src; keep the entries in alphabetical order.
lapack_api_src += \
lapack_api/lapack_gecon.cc \
lapack_api/lapack_gels.cc \
lapack_api/lapack_gemm.cc \
lapack_api/lapack_gesv.cc \
lapack_api/lapack_gesv_mixed.cc \
lapack_api/lapack_gesvd.cc \
lapack_api/lapack_getrf.cc \
lapack_api/lapack_getri.cc \
lapack_api/lapack_getrs.cc \
lapack_api/lapack_heev.cc \
lapack_api/lapack_heevd.cc \
lapack_api/lapack_hemm.cc \
lapack_api/lapack_her2k.cc \
lapack_api/lapack_herk.cc \
lapack_api/lapack_lange.cc \
lapack_api/lapack_lanhe.cc \
lapack_api/lapack_lansy.cc \
lapack_api/lapack_lantr.cc \
lapack_api/lapack_pocon.cc \
lapack_api/lapack_posv.cc \
lapack_api/lapack_potrf.cc \
lapack_api/lapack_potri.cc \
lapack_api/lapack_symm.cc \
lapack_api/lapack_syr2k.cc \
lapack_api/lapack_syrk.cc \
lapack_api/lapack_trcon.cc \
lapack_api/lapack_trmm.cc \
lapack_api/lapack_trsm.cc \
# End. Add alphabetically.
Expand Down
117 changes: 117 additions & 0 deletions lapack_api/lapack_gecon.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//------------------------------------------------------------------------------
// Copyright (c) 2017-2023, University of Tennessee
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of the University of Tennessee nor the
// names of its contributors may be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL UNIVERSITY OF TENNESSEE BE LIABLE FOR ANY
// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//------------------------------------------------------------------------------
// This research was supported by the Exascale Computing Project (17-SC-20-SC),
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
//------------------------------------------------------------------------------
// For assistance with SLATE, email <slate-user@icl.utk.edu>.
// You can also join the "SLATE User" Google group by going to
// https://groups.google.com/a/icl.utk.edu/forum/#!forum/slate-user,
// signing in with your Google credentials, and then clicking "Join group".
//------------------------------------------------------------------------------

#include "lapack_slate.hh"

namespace slate {
namespace lapack_api {

// -----------------------------------------------------------------------------
// Local function
// Forward declaration of the type-generic implementation; the extern "C"
// entry points below forward to it, and it is defined at the end of this file.
template <typename scalar_t>
void slate_gecon(
    const char* normstr,
    const int n,
    scalar_t* a,
    const int lda,
    blas::real_type<scalar_t> Anorm,
    blas::real_type<scalar_t>* rcond,
    scalar_t* work,
    int* iwork,
    int* info);

// -----------------------------------------------------------------------------
// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE)

// Each precision-specific entry-point name (s, d, c, z) is redefined through
// BLAS_FORTRAN_NAME, which receives both the lower-case and upper-case spelling.
// NOTE(review): BLAS_FORTRAN_NAME is defined outside this file; presumably it
// selects the exported symbol spelling for the Fortran naming convention in use.
#define slate_sgecon BLAS_FORTRAN_NAME( slate_sgecon, SLATE_SGECON )
#define slate_dgecon BLAS_FORTRAN_NAME( slate_dgecon, SLATE_DGECON )
#define slate_cgecon BLAS_FORTRAN_NAME( slate_cgecon, SLATE_CGECON )
#define slate_zgecon BLAS_FORTRAN_NAME( slate_zgecon, SLATE_ZGECON )

// Single-precision real entry point: reads the scalars passed by pointer and
// forwards everything to the type-generic slate_gecon.
extern "C" void slate_sgecon(
    const char* normstr, const int* n,
    float* a, const int* lda,
    float* Anorm, float* rcond,
    float* work, int* iwork, int* info)
{
    const int order = *n;
    const int ld_a = *lda;
    const float a_norm = *Anorm;
    slate_gecon(normstr, order, a, ld_a, a_norm, rcond, work, iwork, info);
}
// Double-precision real entry point: reads the scalars passed by pointer and
// forwards everything to the type-generic slate_gecon.
extern "C" void slate_dgecon(
    const char* normstr, const int* n,
    double* a, const int* lda,
    double* Anorm, double* rcond,
    double* work, int* iwork, int* info)
{
    const int order = *n;
    const int ld_a = *lda;
    const double a_norm = *Anorm;
    slate_gecon(normstr, order, a, ld_a, a_norm, rcond, work, iwork, info);
}
// Single-precision complex entry point: reads the scalars passed by pointer and
// forwards everything to the type-generic slate_gecon. The norm and the
// reciprocal condition estimate are real-valued (float).
extern "C" void slate_cgecon(
    const char* normstr, const int* n,
    std::complex<float>* a, const int* lda,
    float* Anorm, float* rcond,
    std::complex<float>* work, int* iwork, int* info)
{
    const int order = *n;
    const int ld_a = *lda;
    const float a_norm = *Anorm;
    slate_gecon(normstr, order, a, ld_a, a_norm, rcond, work, iwork, info);
}
// Double-precision complex entry point: reads the scalars passed by pointer and
// forwards everything to the type-generic slate_gecon. The norm and the
// reciprocal condition estimate are real-valued (double).
extern "C" void slate_zgecon(
    const char* normstr, const int* n,
    std::complex<double>* a, const int* lda,
    double* Anorm, double* rcond,
    std::complex<double>* work, int* iwork, int* info)
{
    const int order = *n;
    const int ld_a = *lda;
    const double a_norm = *Anorm;
    slate_gecon(normstr, order, a, ld_a, a_norm, rcond, work, iwork, info);
}

// -----------------------------------------------------------------------------

// Type generic function calls the SLATE routine
template <typename scalar_t>
void slate_gecon(const char* normstr, const int n, scalar_t* a, const int lda, blas::real_type<scalar_t> Anorm, blas::real_type<scalar_t>* rcond, scalar_t* work, int* iwork, int* info)
{
// Start timing
static int verbose = slate_lapack_set_verbose();
double timestart = 0.0;
if (verbose) timestart = omp_get_wtime();

// Check and initialize MPI, else SLATE calls to MPI will fail
int initialized, provided;
MPI_Initialized(&initialized);
if (! initialized)
MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);

int64_t lookahead = 1;
int64_t p = 1;
int64_t q = 1;
static slate::Target target = slate_lapack_set_target();

// sizes
lapack::Norm norm = lapack::char2norm(normstr[0]);
static int64_t nb = slate_lapack_set_nb(target);

// create SLATE matrix from the LAPACK data
auto A = slate::Matrix<scalar_t>::fromLAPACK(n, n, a, lda, nb, p, q, MPI_COMM_WORLD);

// solve
*rcond = slate::gecondest( norm, A, Anorm, {
{slate::Option::Lookahead, lookahead},
{slate::Option::Target, target}
});

// todo: get a real value for info
*info = 0;

if (verbose) std::cout << "slate_lapack_api: " << slate_lapack_scalar_t_to_char(a) << "gecon(" << normstr[0] << "," << n << "," << (void*)a << "," << lda << "," << Anorm << "," << (void*)rcond << "," << (void*)work << "," << (void*)iwork << "," << *info << ") " << (omp_get_wtime()-timestart) << " sec " << "nb:" << nb << " max_threads:" << omp_get_max_threads() << "\n";
}

} // namespace lapack_api
} // namespace slate
196 changes: 196 additions & 0 deletions lapack_api/lapack_gesvd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
//------------------------------------------------------------------------------
// Copyright (c) 2017-2023, University of Tennessee
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of the University of Tennessee nor the
// names of its contributors may be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL UNIVERSITY OF TENNESSEE BE LIABLE FOR ANY
// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//------------------------------------------------------------------------------
// This research was supported by the Exascale Computing Project (17-SC-20-SC),
// a collaborative effort of two U.S. Department of Energy organizations (Office
// of Science and the National Nuclear Security Administration) responsible for
// the planning and preparation of a capable exascale ecosystem, including
// software, applications, hardware, advanced system engineering and early
// testbed platforms, in support of the nation's exascale computing imperative.
//------------------------------------------------------------------------------
// For assistance with SLATE, email <slate-user@icl.utk.edu>.
// You can also join the "SLATE User" Google group by going to
// https://groups.google.com/a/icl.utk.edu/forum/#!forum/slate-user,
// signing in with your Google credentials, and then clicking "Join group".
//------------------------------------------------------------------------------

#include "lapack_slate.hh"

namespace slate {
namespace lapack_api {

// -----------------------------------------------------------------------------
// Local function
// Forward declaration of the type-generic implementation; the extern "C"
// entry points below forward to it, and it is defined at the end of this file.
template <typename scalar_t>
void slate_gesvd(
    const char* jobustr,
    const char* jobvtstr,
    const int m,
    const int n,
    scalar_t* a,
    const int lda,
    blas::real_type<scalar_t>* s,
    scalar_t* u,
    const int ldu,
    scalar_t* vt,
    const int ldvt,
    scalar_t* work,
    const int lwork,
    int* info);

// -----------------------------------------------------------------------------
// C interfaces (FORTRAN_UPPER, FORTRAN_LOWER, FORTRAN_UNDERSCORE)

// Each precision-specific entry-point name (s, d, c, z) is redefined through
// BLAS_FORTRAN_NAME, which receives both the lower-case and upper-case spelling.
// NOTE(review): BLAS_FORTRAN_NAME is defined outside this file; presumably it
// selects the exported symbol spelling for the Fortran naming convention in use.
#define slate_sgesvd BLAS_FORTRAN_NAME( slate_sgesvd, SLATE_SGESVD )
#define slate_dgesvd BLAS_FORTRAN_NAME( slate_dgesvd, SLATE_DGESVD )
#define slate_cgesvd BLAS_FORTRAN_NAME( slate_cgesvd, SLATE_CGESVD )
#define slate_zgesvd BLAS_FORTRAN_NAME( slate_zgesvd, SLATE_ZGESVD )

// Single-precision real entry point: reads the scalars passed by pointer and
// forwards everything to the type-generic slate_gesvd.
extern "C" void slate_sgesvd(
    const char* jobustr, const char* jobvtstr,
    const int* m, const int* n,
    float* a, const int* lda,
    float* s,
    float* u, const int* ldu,
    float* vt, const int* ldvt,
    float* work, const int* lwork,
    int* info)
{
    const int rows = *m;
    const int cols = *n;
    const int ld_a = *lda;
    const int ld_u = *ldu;
    const int ld_vt = *ldvt;
    const int work_len = *lwork;
    slate_gesvd(jobustr, jobvtstr, rows, cols, a, ld_a, s, u, ld_u, vt, ld_vt, work, work_len, info);
}
// Double-precision real entry point: reads the scalars passed by pointer and
// forwards everything to the type-generic slate_gesvd.
extern "C" void slate_dgesvd(
    const char* jobustr, const char* jobvtstr,
    const int* m, const int* n,
    double* a, const int* lda,
    double* s,
    double* u, const int* ldu,
    double* vt, const int* ldvt,
    double* work, const int* lwork,
    int* info)
{
    const int rows = *m;
    const int cols = *n;
    const int ld_a = *lda;
    const int ld_u = *ldu;
    const int ld_vt = *ldvt;
    const int work_len = *lwork;
    slate_gesvd(jobustr, jobvtstr, rows, cols, a, ld_a, s, u, ld_u, vt, ld_vt, work, work_len, info);
}
// Single-precision complex entry point: reads the scalars passed by pointer and
// forwards everything to the type-generic slate_gesvd. The rwork argument is
// accepted but not forwarded.
extern "C" void slate_cgesvd(
    const char* jobustr, const char* jobvtstr,
    const int* m, const int* n,
    std::complex<float>* a, const int* lda,
    float* s,
    std::complex<float>* u, const int* ldu,
    std::complex<float>* vt, const int* ldvt,
    std::complex<float>* work, const int* lwork,
    float* rwork,
    int* info)
{
    const int rows = *m;
    const int cols = *n;
    const int ld_a = *lda;
    const int ld_u = *ldu;
    const int ld_vt = *ldvt;
    const int work_len = *lwork;
    slate_gesvd(jobustr, jobvtstr, rows, cols, a, ld_a, s, u, ld_u, vt, ld_vt, work, work_len, info);
}
// Double-precision complex entry point: reads the scalars passed by pointer and
// forwards everything to the type-generic slate_gesvd. The rwork argument is
// accepted but not forwarded.
extern "C" void slate_zgesvd(
    const char* jobustr, const char* jobvtstr,
    const int* m, const int* n,
    std::complex<double>* a, const int* lda,
    double* s,
    std::complex<double>* u, const int* ldu,
    std::complex<double>* vt, const int* ldvt,
    std::complex<double>* work, const int* lwork,
    double* rwork,
    int* info)
{
    const int rows = *m;
    const int cols = *n;
    const int ld_a = *lda;
    const int ld_u = *ldu;
    const int ld_vt = *ldvt;
    const int work_len = *lwork;
    slate_gesvd(jobustr, jobvtstr, rows, cols, a, ld_a, s, u, ld_u, vt, ld_vt, work, work_len, info);
}

// -----------------------------------------------------------------------------

// Type generic function calls the SLATE routine.
//
// Computes the singular value decomposition of the m-by-n column-major array
// `a` (leading dimension lda) by wrapping the caller's storage as SLATE
// matrices on a 1x1 process grid and calling slate::svd.
//
//   jobustr   first character selects how U is returned:
//               'A' m-by-m in u, 'S' m-by-min(m,n) in u,
//               'O' m-by-min(m,n) copied into the leading columns of a,
//               'N' U is left empty.
//   jobvtstr  first character selects how VT is returned:
//               'A' n-by-n in vt, 'S' min(m,n)-by-n in vt,
//               'O' min(m,n)-by-n copied into the leading rows of a,
//               'N' VT is left empty.
//             jobustr and jobvtstr may not both be 'O' (as in LAPACK xGESVD).
//   s         receives the min(m,n) singular values.
//   work      for an 'O' job, used as storage for U or VT when lwork is large
//             enough; otherwise SLATE-owned tiles are allocated instead.
//   lwork     -1 requests a workspace query: the size used for the 'O' job
//             (or 0) is written to work[0] and nothing is computed.
//   info      0 on success, 1 for an unrecognized jobustr, 2 for an
//             unrecognized jobvtstr or when both jobs are 'O'.
//             NOTE(review): LAPACK reports illegal arguments with negative
//             info values; the positive codes are kept for compatibility with
//             existing callers -- confirm whether they should become -1 / -2.
template <typename scalar_t>
void slate_gesvd(const char* jobustr, const char* jobvtstr, const int m, const int n, scalar_t* a, const int lda, blas::real_type<scalar_t>* s, scalar_t* u, const int ldu, scalar_t* vt, const int ldvt, scalar_t* work, const int lwork, int* info)
{
    // Start timing
    static int verbose = slate_lapack_set_verbose();
    double timestart = 0.0;
    if (verbose) timestart = omp_get_wtime();

    // sizes
    static slate::Target target = slate_lapack_set_target();
    static int64_t nb = slate_lapack_set_nb(target);
    int64_t min_mn = std::min( m, n );

    // TODO check args more carefully
    *info = 0;

    if (lwork == -1) {
        // Workspace query: report how much of `work` an 'O' job can use.
        if (jobustr[0] == 'O') {
            work[0] = m * min_mn;
        }
        else if (jobvtstr[0] == 'O') {
            work[0] = min_mn * n;
        }
        else {
            work[0] = 0;
        }
    }
    else {
        // Check and initialize MPI, else SLATE calls to MPI will fail
        int initialized, provided;
        MPI_Initialized(&initialized);
        if (! initialized)
            MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &provided);

        int64_t lookahead = 1;
        int64_t p = 1;
        int64_t q = 1;

        // create SLATE matrix from the LAPACK data
        auto A = slate::Matrix<scalar_t>::fromLAPACK(m, n, a, lda, nb, p, q, MPI_COMM_WORLD);
        std::vector< blas::real_type<scalar_t> > Sigma_( min_mn );

        slate::Matrix<scalar_t> U;
        switch (jobustr[0]) {
            case 'A':
                U = slate::Matrix<scalar_t>::fromLAPACK(m, m, u, ldu, nb, p, q, MPI_COMM_WORLD);
                break;
            case 'S':
                U = slate::Matrix<scalar_t>::fromLAPACK(m, min_mn, u, ldu, nb, p, q, MPI_COMM_WORLD);
                break;
            case 'O':
                if (lwork >= m * min_mn) {
                    // U is staged in the caller's workspace, packed with
                    // leading dimension m.
                    U = slate::Matrix<scalar_t>::fromLAPACK(m, min_mn, work, m, nb, p, q, MPI_COMM_WORLD);
                }
                else {
                    // Workspace too small: let SLATE allocate the tiles.
                    U = slate::Matrix<scalar_t>(m, min_mn, nb, p, q, MPI_COMM_WORLD);
                    U.insertLocalTiles(target);
                }
                break;
            case 'N':
                // Leave U empty
                break;
            default:
                *info = 1;
        }

        slate::Matrix<scalar_t> VT;
        switch (jobvtstr[0]) {
            case 'A':
                VT = slate::Matrix<scalar_t>::fromLAPACK(n, n, vt, ldvt, nb, p, q, MPI_COMM_WORLD);
                break;
            case 'S':
                VT = slate::Matrix<scalar_t>::fromLAPACK(min_mn, n, vt, ldvt, nb, p, q, MPI_COMM_WORLD);
                break;
            case 'O':
                if (jobustr[0] == 'O') {
                    // U and VT cannot both overwrite A (and would both alias
                    // the same workspace); reject as an invalid jobvt.
                    *info = 2;
                }
                else if (lwork >= min_mn * n) {
                    // VT is staged in the caller's workspace. It has min_mn
                    // rows, so the packed leading dimension is min_mn; using m
                    // here would run past the min_mn*n elements checked above
                    // whenever m > n.
                    VT = slate::Matrix<scalar_t>::fromLAPACK(min_mn, n, work, min_mn, nb, p, q, MPI_COMM_WORLD);
                }
                else {
                    // Workspace too small: let SLATE allocate the tiles.
                    VT = slate::Matrix<scalar_t>(min_mn, n, nb, p, q, MPI_COMM_WORLD);
                    VT.insertLocalTiles(target);
                }
                break;
            case 'N':
                // Leave VT empty
                break;
            default:
                *info = 2;
        }

        if (*info == 0) {
            // solve
            slate::svd( A, Sigma_, U, VT, {
                {slate::Option::Lookahead, lookahead},
                {slate::Option::Target, target}
            });

            std::copy(Sigma_.begin(), Sigma_.end(), s);

            if (jobustr[0] == 'O') {
                // U is m-by-min_mn: overwrite the leading min_mn columns of A.
                auto A_slice = A.slice( 0, m-1, 0, min_mn-1 );
                slate::copy( U, A_slice, {
                    {slate::Option::Target, target}
                });
            }
            if (jobvtstr[0] == 'O') {
                // VT is min_mn-by-n: overwrite the leading min_mn rows of A.
                auto A_slice = A.slice( 0, min_mn-1, 0, n-1 );
                slate::copy( VT, A_slice, {
                    {slate::Option::Target, target}
                });
            }
        }
    }

    if (verbose) std::cout << "slate_lapack_api: " << slate_lapack_scalar_t_to_char(a) << "gesvd(" << jobustr[0] << "," << jobvtstr[0] << "," << m << "," << n << "," << (void*)a << "," << lda << "," << (void*)s << "," << (void*)u << "," << ldu << "," << (void*)vt << "," << ldvt << "," << (void*)work << "," << lwork << "," << *info << ") " << (omp_get_wtime()-timestart) << " sec " << "nb:" << nb << " max_threads:" << omp_get_max_threads() << "\n";
}

} // namespace lapack_api
} // namespace slate
Loading

0 comments on commit f9ac8c7

Please sign in to comment.