diff --git a/core/cu29_runtime/src/pool.rs b/core/cu29_runtime/src/pool.rs index 7a03fa919..9fcd466de 100644 --- a/core/cu29_runtime/src/pool.rs +++ b/core/cu29_runtime/src/pool.rs @@ -20,14 +20,8 @@ impl ElementType for T where { } -pub trait ArrayLike: Debug { +pub trait ArrayLike: Deref + DerefMut + Debug { type Element: ElementType; - fn slice(&self) -> &[Self::Element]; - fn slice_mut(&mut self) -> &mut [Self::Element]; - fn len(&self) -> usize; - fn is_empty(&self) -> bool { - self.len() == 0 - } } pub enum CuHandleInner { @@ -55,14 +49,14 @@ where { pub fn slice(&self) -> &[T::Element] { match self { - CuHandleInner::Pooled(pooled) => pooled.slice(), - CuHandleInner::Detached(detached) => detached.slice(), + CuHandleInner::Pooled(pooled) => pooled, + CuHandleInner::Detached(detached) => detached, } } pub fn slice_mut(&mut self) -> &mut [T::Element] { match self { - CuHandleInner::Pooled(pooled) => pooled.deref_mut().slice_mut(), - CuHandleInner::Detached(detached) => detached.slice_mut(), + CuHandleInner::Pooled(pooled) => pooled.deref_mut(), + CuHandleInner::Detached(detached) => detached, } } @@ -115,14 +109,8 @@ where fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { let inner = self.0.lock().unwrap(); match inner.deref() { - CuHandleInner::Pooled(pooled) => { - let slice = pooled.slice(); - slice.encode(encoder) - } - CuHandleInner::Detached(detached) => { - let slice = detached.slice(); - slice.encode(encoder) - } + CuHandleInner::Pooled(pooled) => pooled.encode(encoder), + CuHandleInner::Detached(detached) => detached.encode(encoder), } } } @@ -190,28 +178,22 @@ impl CuPool for CuHostMemoryPool { let to_handle = self.acquire().expect("No available buffers in the pool"); match from.0.lock().unwrap().deref() { - CuHandleInner::Detached(source) => { - let source_slice = source.slice(); - match to_handle.0.lock().unwrap().deref_mut() { - CuHandleInner::Detached(destination) => { - destination.slice_mut().copy_from_slice(source_slice); - } - CuHandleInner::Pooled(destination) => { - destination.slice_mut().copy_from_slice(source_slice); - } + CuHandleInner::Detached(source) => match to_handle.0.lock().unwrap().deref_mut() { + CuHandleInner::Detached(destination) => { + destination.copy_from_slice(source); } - } - CuHandleInner::Pooled(source) => { - let source_slice = source.slice(); - match to_handle.0.lock().unwrap().deref_mut() { - CuHandleInner::Detached(destination) => { - destination.slice_mut().copy_from_slice(source_slice); - } - CuHandleInner::Pooled(destination) => { - destination.slice_mut().copy_from_slice(source_slice); - } + CuHandleInner::Pooled(destination) => { + destination.copy_from_slice(source); } - } + }, + CuHandleInner::Pooled(source) => match to_handle.0.lock().unwrap().deref_mut() { + CuHandleInner::Detached(destination) => { + destination.copy_from_slice(source); + } + CuHandleInner::Pooled(destination) => { + destination.copy_from_slice(source); + } + }, } to_handle } @@ -233,18 +215,6 @@ where impl ArrayLike for Vec { type Element = E; - - fn slice(&self) -> &[Self::Element] { - self.as_slice() - } - - fn slice_mut(&mut self) -> &mut [Self::Element] { - self.as_mut_slice() - } - - fn len(&self) -> usize { - self.len() - } } #[cfg(feature = "cuda")] @@ -253,12 +223,50 @@ mod cuda { use cudarc::driver::{CudaDevice, CudaSlice, DeviceRepr, DeviceSlice, ValidAsZeroBits}; use std::sync::Arc; + #[derive(Debug)] + pub struct CudaSliceWrapper(CudaSlice); + + impl Deref for CudaSliceWrapper + where + E: ElementType, + { + type Target = [E]; + + fn deref(&self) -> &Self::Target { + // Implement logic to return a slice + panic!("You need to copy data to host memory before accessing it."); + } + } + + impl DerefMut for CudaSliceWrapper + where + E: ElementType, + { + fn deref_mut(&mut self) -> &mut Self::Target { + panic!("You need to copy data to host memory before accessing it."); + } + } + + impl ArrayLike for CudaSliceWrapper { + type Element = E; + } + + impl CudaSliceWrapper { + pub fn as_cuda_slice(&self) -> &CudaSlice { + &self.0 + } + + pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice { + &mut self.0 + } + } + pub struct CuCudaPool where E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin, { device: Arc, - pool: Arc>>, + pool: Arc>>, } impl CuCudaPool { @@ -270,9 +278,11 @@ mod cuda { Self { device: device.clone(), pool: Arc::new(Pool::new(nb_buffers, || { - device - .alloc_zeros(nb_element_per_buffer) - .expect("Failed to allocate device memory") + CudaSliceWrapper( + device + .alloc_zeros(nb_element_per_buffer) + .expect("Failed to allocate device memory"), + ) })), } } @@ -280,7 +290,7 @@ mod cuda { #[allow(dead_code)] pub fn new_with(device: Arc, nb_buffers: usize, f: F) -> Self where - F: Fn() -> CudaSlice, + F: Fn() -> CudaSliceWrapper, { Self { device: device.clone(), @@ -289,24 +299,8 @@ mod cuda { } } - impl ArrayLike for CudaSlice { - type Element = E; - - fn slice(&self) -> &[Self::Element] { - panic!("You need to copy a handle to the host memory or a shared memory before accessing it") - } - - fn slice_mut(&mut self) -> &mut [Self::Element] { - panic!("You need to copy a handle to the host memory or a shared memory before accessing it") - } - - fn len(&self) -> usize { - DeviceSlice::::len(self) - } - } - - impl CuPool> for CuCudaPool { - fn acquire(&self) -> Option>> { + impl CuPool> for CuCudaPool { + fn acquire(&self) -> Option>> { self.pool .try_pull_owned() .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x))))) @@ -317,43 +311,37 @@ mod cuda { } /// Copy from host to device - fn copy_from(&self, from: &mut CuHandle) -> CuHandle> + fn copy_from(&self, from_handle: &mut CuHandle) -> CuHandle> where O: ArrayLike, { let to_handle = self.acquire().expect("No available buffers in the pool"); - match from.0.lock().unwrap().deref() { - CuHandleInner::Detached(detached) => { - let from_slice = detached.slice(); - match to_handle.0.lock().unwrap().deref_mut() { - CuHandleInner::Detached(detached) => { - self.device - .htod_sync_copy_into(from_slice, detached) - .expect("Failed to copy data to device"); - } - CuHandleInner::Pooled(pooled) => { - self.device - .htod_sync_copy_into(from_slice, pooled.deref_mut()) - .expect("Failed to copy data to device"); - } + match from_handle.0.lock().unwrap().deref() { + CuHandleInner::Detached(from) => match to_handle.0.lock().unwrap().deref_mut() { + CuHandleInner::Detached(CudaSliceWrapper(to)) => { + self.device + .htod_sync_copy_into(from, to) + .expect("Failed to copy data to device"); } - } - CuHandleInner::Pooled(pooled) => { - let from_slice = pooled.slice(); - match to_handle.0.lock().unwrap().deref_mut() { - CuHandleInner::Detached(detached) => { - self.device - .htod_sync_copy_into(from_slice, detached) - .expect("Failed to copy data to device"); - } - CuHandleInner::Pooled(pooled) => { - self.device - .htod_sync_copy_into(from_slice, pooled.deref_mut()) - .expect("Failed to copy data to device"); - } + CuHandleInner::Pooled(to) => { + self.device + .htod_sync_copy_into(from, to.as_cuda_slice_mut()) + .expect("Failed to copy data to device"); } - } + }, + CuHandleInner::Pooled(from) => match to_handle.0.lock().unwrap().deref_mut() { + CuHandleInner::Detached(CudaSliceWrapper(to)) => { + self.device + .htod_sync_copy_into(from, to) + .expect("Failed to copy data to device"); + } + CuHandleInner::Pooled(to) => { + self.device + .htod_sync_copy_into(from, to.as_cuda_slice_mut()) + .expect("Failed to copy data to device"); + } + }, } to_handle } @@ -364,7 +352,7 @@ mod cuda { E: ElementType + ValidAsZeroBits + DeviceRepr, T: ArrayLike, { - type O = CudaSlice; + type O = CudaSliceWrapper; /// Copy from device to host fn copy_into( @@ -381,15 +369,12 @@ mod cuda { match destination_handle.0.lock().unwrap().deref_mut() { CuHandleInner::Pooled(ref mut destination) => { self.device - .dtoh_sync_copy_into( - source.deref(), - destination.deref_mut().slice_mut(), - ) + .dtoh_sync_copy_into(source.as_cuda_slice(), destination) .expect("Failed to copy data to device"); } CuHandleInner::Detached(ref mut destination) => { self.device - .dtoh_sync_copy_into(source.deref(), destination.slice_mut()) + .dtoh_sync_copy_into(source.as_cuda_slice(), destination) .expect("Failed to copy data to device"); } } @@ -398,12 +383,12 @@ mod cuda { match destination_handle.0.lock().unwrap().deref_mut() { CuHandleInner::Pooled(ref mut destination) => { self.device - .dtoh_sync_copy_into(source, destination.deref_mut().slice_mut()) + .dtoh_sync_copy_into(source.as_cuda_slice(), destination) .expect("Failed to copy data to device"); } CuHandleInner::Detached(ref mut destination) => { self.device - .dtoh_sync_copy_into(source, destination.slice_mut()) + .dtoh_sync_copy_into(source.as_cuda_slice(), destination) .expect("Failed to copy data to device"); } } @@ -462,26 +447,6 @@ impl Drop for AlignedBuffer { } } -// Implement ArrayLike for fixed-size arrays -impl ArrayLike for [T; N] -where - T: ElementType, -{ - type Element = T; - - fn slice(&self) -> &[Self::Element] { - self - } - - fn slice_mut(&mut self) -> &mut [Self::Element] { - self - } - - fn len(&self) -> usize { - N - } -} - #[cfg(test)] mod tests { use super::*;