Skip to content

Commit

Permalink
Merge pull request #153 from neil-lindquist/optimize-struct-sizes
Browse files Browse the repository at this point in the history
Optimize class sizes
  • Loading branch information
neil-lindquist authored Dec 13, 2023
2 parents d514136 + f0b465c commit 979f9d1
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 103 deletions.
4 changes: 2 additions & 2 deletions include/slate/BaseBandMatrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ void BaseBandMatrix<scalar_t>::allocateBatchArrays(
int64_t batch_size, int64_t num_arrays)
{
if (batch_size == 0) {
for (int device = 0; device < this->num_devices_; ++device)
for (int device = 0; device < this->num_devices(); ++device)
batch_size = std::max(batch_size, getMaxDeviceTiles(device));
}
this->storage_->allocateBatchArrays(batch_size, num_arrays);
Expand All @@ -264,7 +264,7 @@ template <typename scalar_t>
//------------------------------------------------------------------------------
/// Reserves the device workspace needed to hold the tiles of this matrix.
/// The workspace is sized to the largest number of tiles mapped to any
/// single device, so one allocation suffices for every device.
template <typename scalar_t>
void BaseBandMatrix<scalar_t>::reserveDeviceWorkspace()
{
    // Diff artifact fixed: the captured hunk contained both the old
    // `num_devices_` loop line and the new `num_devices()` loop line;
    // only the post-change accessor-based line is kept.
    int64_t num_tiles = 0;
    for (int device = 0; device < this->num_devices(); ++device)
        num_tiles = std::max(num_tiles, getMaxDeviceTiles(device));
    this->storage_->reserveDeviceWorkspace(num_tiles);
}
Expand Down
19 changes: 2 additions & 17 deletions include/slate/BaseMatrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public:
}

/// Returns number of devices (per MPI process) to distribute matrix to.
// Post-change version kept: the static BaseMatrix::num_devices_ member was
// removed in this commit, so the count is obtained from MatrixStorage.
int num_devices() const { return MatrixStorage<scalar_t>::num_devices(); }

void gridinfo( GridOrder* order, int* nprow, int* npcol,
int* myrow, int* mycol );
Expand Down Expand Up @@ -774,9 +774,6 @@ protected:
/// shared storage of tiles and buffers
std::shared_ptr< MatrixStorage<scalar_t> > storage_;

// ----- consider where to put these, here or in MatrixStorage
static int num_devices_;

MPI_Comm mpi_comm_;
MPI_Group mpi_group_;
int mpi_rank_;
Expand Down Expand Up @@ -888,10 +885,6 @@ BaseMatrix<scalar_t>::BaseMatrix(
MPI_Comm_rank(mpi_comm_, &mpi_rank_));
slate_mpi_call(
MPI_Comm_group(mpi_comm_, &mpi_group_));

// todo: these are static, but we (re-)initialize with each matrix.
// todo: similar code in BaseMatrix(...) and MatrixStorage(...)
num_devices_ = storage_->num_devices_;
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -955,10 +948,6 @@ BaseMatrix<scalar_t>::BaseMatrix(
MPI_Comm_rank(mpi_comm_, &mpi_rank_));
slate_mpi_call(
MPI_Comm_group(mpi_comm_, &mpi_group_));

// todo: these are static, but we (re-)initialize with each matrix.
// todo: similar code in BaseMatrix(...) and MatrixStorage(...)
num_devices_ = storage_->num_devices_;
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -3894,7 +3883,7 @@ std::tuple<int64_t, int64_t, int>
assert(0 <= j && j < nt());
// Given AnyDevice = -3, AllDevices = -2, HostNum = -1,
// GPU devices 0, 1, ..., num_devices-1.
assert( AnyDevice <= device && device < num_devices_ );
assert( AnyDevice <= device && device < num_devices() );
if (op_ == Op::NoTrans)
return { ioffset_ + i, joffset_ + j, device };
else
Expand Down Expand Up @@ -4001,10 +3990,6 @@ void BaseMatrix<scalar_t>::releaseRemoteWorkspace(
}
}

//------------------------------------------------------------------------------
template <typename scalar_t>
int BaseMatrix<scalar_t>::num_devices_ = 0;

//------------------------------------------------------------------------------
// from ScaLAPACK's indxg2l
// todo: where to put utilities like this?
Expand Down
4 changes: 2 additions & 2 deletions include/slate/BaseTrapezoidMatrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ void BaseTrapezoidMatrix<scalar_t>::allocateBatchArrays(
int64_t batch_size, int64_t num_arrays)
{
if (batch_size == 0) {
for (int device = 0; device < this->num_devices_; ++device)
for (int device = 0; device < this->num_devices(); ++device)
batch_size = std::max(batch_size, getMaxDeviceTiles(device));
}
this->storage_->allocateBatchArrays(batch_size, num_arrays);
Expand All @@ -635,7 +635,7 @@ template <typename scalar_t>
//------------------------------------------------------------------------------
/// Reserves the device workspace needed to hold the tiles of this matrix.
/// The workspace is sized to the largest number of tiles mapped to any
/// single device, so one allocation suffices for every device.
template <typename scalar_t>
void BaseTrapezoidMatrix<scalar_t>::reserveDeviceWorkspace()
{
    // Diff artifact fixed: both the old `num_devices_` line and the new
    // `num_devices()` line were present; only the post-change line is kept.
    int64_t num_tiles = 0;
    for (int device = 0; device < this->num_devices(); ++device)
        num_tiles = std::max(num_tiles, getMaxDeviceTiles(device));
    this->storage_->reserveDeviceWorkspace(num_tiles);
}
Expand Down
4 changes: 2 additions & 2 deletions include/slate/Matrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ void Matrix<scalar_t>::allocateBatchArrays(
int64_t batch_size, int64_t num_arrays)
{
if (batch_size == 0) {
for (int device = 0; device < this->num_devices_; ++device)
for (int device = 0; device < this->num_devices(); ++device)
batch_size = std::max(batch_size, getMaxDeviceTiles(device));
}
this->storage_->allocateBatchArrays(batch_size, num_arrays);
Expand All @@ -746,7 +746,7 @@ template <typename scalar_t>
//------------------------------------------------------------------------------
/// Reserves the device workspace needed to hold the tiles of this matrix.
/// The workspace is sized to the largest number of tiles mapped to any
/// single device, so one allocation suffices for every device.
template <typename scalar_t>
void Matrix<scalar_t>::reserveDeviceWorkspace()
{
    // Diff artifact fixed: both the old `num_devices_` line and the new
    // `num_devices()` line were present; only the post-change line is kept.
    int64_t num_tiles = 0;
    for (int device = 0; device < this->num_devices(); ++device)
        num_tiles = std::max(num_tiles, getMaxDeviceTiles(device));
    this->storage_->reserveDeviceWorkspace(num_tiles);
}
Expand Down
38 changes: 20 additions & 18 deletions include/slate/Tile.hh
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ MatrixType conjTranspose( MatrixType&& A )
/// and who owns (allocated, deallocates) the data.
/// @ingroup enum
///
// Post-change version kept: the captured hunk showed the old unvalued enum
// (without its closing brace) immediately followed by this new definition.
// The underlying type is char with printable values — presumably chosen so
// a TileKind is 1 byte (smaller Tile) and readable in a debugger; confirm.
enum class TileKind : char {
    Workspace  = 'w',  ///< SLATE allocated workspace tile
    SlateOwned = 'o',  ///< SLATE allocated origin tile
    UserOwned  = 'u',  ///< User owned origin tile
};

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -179,6 +179,7 @@ public:
int64_t stride() const { return stride_; }

/// Sets column stride of this tile.
/// NOTE(review): deprecated in favor of setLayout — presumably so stride_
/// stays consistent with the tile's layout state rather than being mutated
/// directly; see the attribute message for the removal timeline.
[[deprecated( "Use setLayout to manage the Tile's layout. Will be removed 2024-12." )]]
void stride(int64_t in_stride) { stride_ = in_stride; }

/// Returns const pointer to data, i.e., A(0,0), where A is this tile
Expand Down Expand Up @@ -224,6 +225,12 @@ public:
/// e.g., via a fromScaLAPACK constructor.
bool allocated() const { return kind_ != TileKind::UserOwned; }

/// Returns the TileKind of this tile (Workspace, SlateOwned, or UserOwned).
/// Marked const for consistency with the other accessors (stride(),
/// allocated()): it reads only kind_, and without const it could not be
/// called on a const Tile.
TileKind kind() const
{
    return kind_;
}

/// Returns number of bytes; but NOT consecutive if stride != mb_.
size_t bytes() const { return sizeof(scalar_t) * size(); }

Expand Down Expand Up @@ -371,11 +378,6 @@ protected:
void nb(int64_t in_nb);
void offset(int64_t i, int64_t j);

void kind(TileKind kind)
{
kind_ = kind;
}

void state(MOSI_State stateIn)
{
switch (stateIn) {
Expand Down Expand Up @@ -405,13 +407,13 @@ protected:
int64_t stride_;
int64_t user_stride_; // Temporarily store user-provided-memory's stride

Op op_;
Uplo uplo_;

scalar_t* data_;
scalar_t* user_data_; // Temporarily point to user-provided memory buffer.
scalar_t* ext_data_; // Points to auxiliary buffer.

Op op_;
Uplo uplo_;

TileKind kind_;
/// layout_: The physical ordering of elements in the data buffer:
/// - ColMajor: elements of a column are 1-strided
Expand All @@ -434,11 +436,11 @@ Tile<scalar_t>::Tile()
nb_(0),
stride_(0),
user_stride_(0),
op_(Op::NoTrans),
uplo_(Uplo::General),
data_(nullptr),
user_data_(nullptr),
ext_data_(nullptr),
op_(Op::NoTrans),
uplo_(Uplo::General),
kind_(TileKind::UserOwned),
layout_(Layout::ColMajor),
user_layout_(Layout::ColMajor),
Expand Down Expand Up @@ -487,11 +489,11 @@ Tile<scalar_t>::Tile(
nb_(nb),
stride_(lda),
user_stride_(lda),
op_(Op::NoTrans),
uplo_(Uplo::General),
data_(A),
user_data_(A),
ext_data_(nullptr),
op_(Op::NoTrans),
uplo_(Uplo::General),
kind_(kind),
layout_(layout),
user_layout_(layout),
Expand Down Expand Up @@ -530,11 +532,11 @@ Tile<scalar_t>::Tile(
nb_(src_tile.nb_),
stride_(lda),
user_stride_(lda),
op_(src_tile.op_),
uplo_(src_tile.uplo_),
data_(A),
user_data_(A),
ext_data_(nullptr),
op_(src_tile.op_),
uplo_(src_tile.uplo_),
kind_(kind),
layout_(src_tile.layout_),
user_layout_(src_tile.user_layout_),
Expand Down
Loading

0 comments on commit 979f9d1

Please sign in to comment.