Skip to content

Commit

Permalink
makes the slices like more idiomatics
Browse files Browse the repository at this point in the history
  • Loading branch information
gbin committed Jan 17, 2025
1 parent 01f76e9 commit ad6fc8b
Showing 1 changed file with 97 additions and 132 deletions.
229 changes: 97 additions & 132 deletions core/cu29_runtime/src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,8 @@ impl<T> ElementType for T where
{
}

pub trait ArrayLike: Debug {
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug {
type Element: ElementType;
fn slice(&self) -> &[Self::Element];
fn slice_mut(&mut self) -> &mut [Self::Element];
fn len(&self) -> usize;
fn is_empty(&self) -> bool {
self.len() == 0
}
}

pub enum CuHandleInner<T: Debug> {
Expand Down Expand Up @@ -55,14 +49,14 @@ where
{
pub fn slice(&self) -> &[T::Element] {
match self {
CuHandleInner::Pooled(pooled) => pooled.slice(),
CuHandleInner::Detached(detached) => detached.slice(),
CuHandleInner::Pooled(pooled) => pooled,
CuHandleInner::Detached(detached) => detached,
}
}
pub fn slice_mut(&mut self) -> &mut [T::Element] {
match self {
CuHandleInner::Pooled(pooled) => pooled.deref_mut().slice_mut(),
CuHandleInner::Detached(detached) => detached.slice_mut(),
CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
CuHandleInner::Detached(detached) => detached,
}
}

Expand Down Expand Up @@ -115,14 +109,8 @@ where
fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
let inner = self.0.lock().unwrap();
match inner.deref() {
CuHandleInner::Pooled(pooled) => {
let slice = pooled.slice();
slice.encode(encoder)
}
CuHandleInner::Detached(detached) => {
let slice = detached.slice();
slice.encode(encoder)
}
CuHandleInner::Pooled(pooled) => pooled.encode(encoder),
CuHandleInner::Detached(detached) => detached.encode(encoder),
}
}
}
Expand Down Expand Up @@ -190,28 +178,22 @@ impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
let to_handle = self.acquire().expect("No available buffers in the pool");

match from.0.lock().unwrap().deref() {
CuHandleInner::Detached(source) => {
let source_slice = source.slice();
match to_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Detached(destination) => {
destination.slice_mut().copy_from_slice(source_slice);
}
CuHandleInner::Pooled(destination) => {
destination.slice_mut().copy_from_slice(source_slice);
}
CuHandleInner::Detached(source) => match to_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Detached(destination) => {
destination.copy_from_slice(source);
}
}
CuHandleInner::Pooled(source) => {
let source_slice = source.slice();
match to_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Detached(destination) => {
destination.slice_mut().copy_from_slice(source_slice);
}
CuHandleInner::Pooled(destination) => {
destination.slice_mut().copy_from_slice(source_slice);
}
CuHandleInner::Pooled(destination) => {
destination.copy_from_slice(source);
}
}
},
CuHandleInner::Pooled(source) => match to_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Detached(destination) => {
destination.copy_from_slice(source);
}
CuHandleInner::Pooled(destination) => {
destination.copy_from_slice(source);
}
},
}
to_handle
}
Expand All @@ -233,18 +215,6 @@ where

impl<E: ElementType + 'static> ArrayLike for Vec<E> {
type Element = E;

fn slice(&self) -> &[Self::Element] {
self.as_slice()
}

fn slice_mut(&mut self) -> &mut [Self::Element] {
self.as_mut_slice()
}

fn len(&self) -> usize {
self.len()
}
}

#[cfg(feature = "cuda")]
Expand All @@ -253,12 +223,50 @@ mod cuda {
use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, DeviceSlice, ValidAsZeroBits};
use std::sync::Arc;

#[derive(Debug)]
pub struct CudaSliceWrapper<E>(CudaSlice<E>);

impl<E> Deref for CudaSliceWrapper<E>
where
E: ElementType,
{
type Target = [E];

fn deref(&self) -> &Self::Target {
// Implement logic to return a slice
panic!("You need to copy data to host memory before accessing it.");
}
}

impl<E> DerefMut for CudaSliceWrapper<E>
where
E: ElementType,
{
fn deref_mut(&mut self) -> &mut Self::Target {
panic!("You need to copy data to host memory before accessing it.");
}
}

impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
type Element = E;
}

impl<E> CudaSliceWrapper<E> {
pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
&self.0
}

pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
&mut self.0
}
}

pub struct CuCudaPool<E>
where
E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
{
device: Arc<CudaDevice>,
pool: Arc<Pool<CudaSlice<E>>>,
pool: Arc<Pool<CudaSliceWrapper<E>>>,
}

impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
Expand All @@ -270,17 +278,19 @@ mod cuda {
Self {
device: device.clone(),
pool: Arc::new(Pool::new(nb_buffers, || {
device
.alloc_zeros(nb_element_per_buffer)
.expect("Failed to allocate device memory")
CudaSliceWrapper(
device
.alloc_zeros(nb_element_per_buffer)
.expect("Failed to allocate device memory"),
)
})),
}
}

#[allow(dead_code)]
pub fn new_with<F>(device: Arc<CudaDevice>, nb_buffers: usize, f: F) -> Self
where
F: Fn() -> CudaSlice<E>,
F: Fn() -> CudaSliceWrapper<E>,
{
Self {
device: device.clone(),
Expand All @@ -289,24 +299,8 @@ mod cuda {
}
}

impl<E: ElementType> ArrayLike for CudaSlice<E> {
type Element = E;

fn slice(&self) -> &[Self::Element] {
panic!("You need to copy a handle to the host memory or a shared memory before accessing it")
}

fn slice_mut(&mut self) -> &mut [Self::Element] {
panic!("You need to copy a handle to the host memory or a shared memory before accessing it")
}

fn len(&self) -> usize {
DeviceSlice::<E>::len(self)
}
}

impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuPool<CudaSlice<E>> for CuCudaPool<E> {
fn acquire(&self) -> Option<CuHandle<CudaSlice<E>>> {
impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E> {
fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
self.pool
.try_pull_owned()
.map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
Expand All @@ -317,43 +311,37 @@ mod cuda {
}

/// Copy from host to device
fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<CudaSlice<E>>
fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
where
O: ArrayLike<Element = E>,
{
let to_handle = self.acquire().expect("No available buffers in the pool");

match from.0.lock().unwrap().deref() {
CuHandleInner::Detached(detached) => {
let from_slice = detached.slice();
match to_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Detached(detached) => {
self.device
.htod_sync_copy_into(from_slice, detached)
.expect("Failed to copy data to device");
}
CuHandleInner::Pooled(pooled) => {
self.device
.htod_sync_copy_into(from_slice, pooled.deref_mut())
.expect("Failed to copy data to device");
}
match from_handle.0.lock().unwrap().deref() {
CuHandleInner::Detached(from) => match to_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Detached(CudaSliceWrapper(to)) => {
self.device
.htod_sync_copy_into(from, to)
.expect("Failed to copy data to device");
}
}
CuHandleInner::Pooled(pooled) => {
let from_slice = pooled.slice();
match to_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Detached(detached) => {
self.device
.htod_sync_copy_into(from_slice, detached)
.expect("Failed to copy data to device");
}
CuHandleInner::Pooled(pooled) => {
self.device
.htod_sync_copy_into(from_slice, pooled.deref_mut())
.expect("Failed to copy data to device");
}
CuHandleInner::Pooled(to) => {
self.device
.htod_sync_copy_into(from, to.as_cuda_slice_mut())
.expect("Failed to copy data to device");
}
}
},
CuHandleInner::Pooled(from) => match to_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Detached(CudaSliceWrapper(to)) => {
self.device
.htod_sync_copy_into(from, to)
.expect("Failed to copy data to device");
}
CuHandleInner::Pooled(to) => {
self.device
.htod_sync_copy_into(from, to.as_cuda_slice_mut())
.expect("Failed to copy data to device");
}
},
}
to_handle
}
Expand All @@ -364,7 +352,7 @@ mod cuda {
E: ElementType + ValidAsZeroBits + DeviceRepr,
T: ArrayLike<Element = E>,
{
type O = CudaSlice<T::Element>;
type O = CudaSliceWrapper<T::Element>;

/// Copy from device to host
fn copy_into(
Expand All @@ -381,15 +369,12 @@ mod cuda {
match destination_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Pooled(ref mut destination) => {
self.device
.dtoh_sync_copy_into(
source.deref(),
destination.deref_mut().slice_mut(),
)
.dtoh_sync_copy_into(source.as_cuda_slice(), destination)
.expect("Failed to copy data to device");
}
CuHandleInner::Detached(ref mut destination) => {
self.device
.dtoh_sync_copy_into(source.deref(), destination.slice_mut())
.dtoh_sync_copy_into(source.as_cuda_slice(), destination)
.expect("Failed to copy data to device");
}
}
Expand All @@ -398,12 +383,12 @@ mod cuda {
match destination_handle.0.lock().unwrap().deref_mut() {
CuHandleInner::Pooled(ref mut destination) => {
self.device
.dtoh_sync_copy_into(source, destination.deref_mut().slice_mut())
.dtoh_sync_copy_into(source.as_cuda_slice(), destination)
.expect("Failed to copy data to device");
}
CuHandleInner::Detached(ref mut destination) => {
self.device
.dtoh_sync_copy_into(source, destination.slice_mut())
.dtoh_sync_copy_into(source.as_cuda_slice(), destination)
.expect("Failed to copy data to device");
}
}
Expand Down Expand Up @@ -462,26 +447,6 @@ impl<E: ElementType> Drop for AlignedBuffer<E> {
}
}

// Implement ArrayLike for fixed-size arrays
impl<T, const N: usize> ArrayLike for [T; N]
where
T: ElementType,
{
type Element = T;

fn slice(&self) -> &[Self::Element] {
self
}

fn slice_mut(&mut self) -> &mut [Self::Element] {
self
}

fn len(&self) -> usize {
N
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit ad6fc8b

Please sign in to comment.