// slop_alloc/mem.rs

1use std::{rc::Rc, sync::Arc};
2
3use thiserror::Error;
4
5/// The [AllocError] error indicates an allocation failure that may be due to resource exhaustion
6/// or to something wrong when combining the given input arguments with this allocator.
7#[derive(Copy, Clone, PartialEq, Eq, Debug, Error)]
8#[error("allocation error")]
9pub struct AllocError;
10
11#[derive(Copy, Clone, PartialEq, Eq, Debug, Error)]
12#[error("copy error")]
13pub struct CopyError;
14
/// Direction of a memory copy: each variant names the memory space of the source,
/// followed by the memory space of the destination.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum CopyDirection {
    /// Copy from host memory into device memory.
    HostToDevice,
    /// Copy from device memory into host memory.
    DeviceToHost,
    /// Copy between two locations in device memory.
    DeviceToDevice,
}
22
23/// A trait that defines memory operations for a device.
24pub trait DeviceMemory {
25    /// # Safety
26    unsafe fn copy_nonoverlapping(
27        &self,
28        src: *const u8,
29        dst: *mut u8,
30        size: usize,
31        direction: CopyDirection,
32    ) -> Result<(), CopyError>;
33
34    /// TODO
35    ///
36    /// # Safety
37    unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError>;
38}
39
40impl<T: DeviceMemory> DeviceMemory for &T {
41    #[inline]
42    unsafe fn copy_nonoverlapping(
43        &self,
44        src: *const u8,
45        dst: *mut u8,
46        size: usize,
47        direction: CopyDirection,
48    ) -> Result<(), CopyError> {
49        (**self).copy_nonoverlapping(src, dst, size, direction)
50    }
51
52    #[inline]
53    unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
54        (**self).write_bytes(dst, value, size)
55    }
56}
57
58impl<T: DeviceMemory> DeviceMemory for Rc<T> {
59    #[inline]
60    unsafe fn copy_nonoverlapping(
61        &self,
62        src: *const u8,
63        dst: *mut u8,
64        size: usize,
65        direction: CopyDirection,
66    ) -> Result<(), CopyError> {
67        (**self).copy_nonoverlapping(src, dst, size, direction)
68    }
69
70    #[inline]
71    unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
72        (**self).write_bytes(dst, value, size)
73    }
74}
75
76impl<T: DeviceMemory> DeviceMemory for Arc<T> {
77    #[inline]
78    unsafe fn copy_nonoverlapping(
79        &self,
80        src: *const u8,
81        dst: *mut u8,
82        size: usize,
83        direction: CopyDirection,
84    ) -> Result<(), CopyError> {
85        (**self).copy_nonoverlapping(src, dst, size, direction)
86    }
87
88    #[inline]
89    unsafe fn write_bytes(&self, dst: *mut u8, value: u8, size: usize) -> Result<(), CopyError> {
90        (**self).write_bytes(dst, value, size)
91    }
92}