// torsh_backend/cpu/buffer.rs
1//! CPU Buffer Implementation
2
3use crate::buffer::{generate_buffer_id, BufferHandle};
4use crate::error::{BackendError, BackendResult};
5use crate::{Buffer, BufferDescriptor, BufferUsage, Device};
6
7#[cfg(feature = "std")]
8use std::sync::{Arc, RwLock};
9
10#[cfg(not(feature = "std"))]
11use alloc::{sync::Arc, vec::Vec};
12
13#[cfg(not(feature = "std"))]
14use spin::RwLock;
15
/// CPU buffer implementation using system memory
///
/// Cloning is cheap: the backing bytes live behind an `Arc`, so clones
/// alias the same storage rather than copying it.
#[derive(Debug, Clone)]
pub struct CpuBuffer {
    // Shared, lock-guarded backing storage for the buffer bytes.
    data: Arc<RwLock<Vec<u8>>>,
    // Logical size in bytes, fixed at construction.
    size: usize,
    // Usage flags recorded at creation time.
    usage: BufferUsage,
}
23
24impl CpuBuffer {
25    /// Create a new CPU buffer
26    pub fn new(size: usize, usage: BufferUsage) -> BackendResult<Self> {
27        // Check for reasonable size limits to avoid capacity overflow
28        if size > isize::MAX as usize {
29            return Err(torsh_core::error::TorshError::BackendError(format!(
30                "Buffer size {} is too large (exceeds maximum allowed size)",
31                size
32            )));
33        }
34
35        let data = match size {
36            0 => Vec::new(), // Handle zero-size case
37            size => {
38                // Try to allocate the vector, catching any allocation errors
39                match std::panic::catch_unwind(|| vec![0u8; size]) {
40                    Ok(vec) => vec,
41                    Err(_) => {
42                        return Err(torsh_core::error::TorshError::BackendError(format!(
43                            "Failed to allocate {} bytes for buffer",
44                            size
45                        )));
46                    }
47                }
48            }
49        };
50
51        Ok(Self {
52            data: Arc::new(RwLock::new(data)),
53            size,
54            usage,
55        })
56    }
57
58    /// Create a CPU buffer and return an abstract Buffer
59    pub fn new_buffer(device: Device, descriptor: &BufferDescriptor) -> BackendResult<Buffer> {
60        let cpu_buffer = Self::new(descriptor.size, descriptor.usage)?;
61
62        // Store the CpuBuffer in the BufferHandle using Generic variant
63        // This avoids the dangling pointer issue
64        let handle = BufferHandle::Generic {
65            handle: Box::new(cpu_buffer),
66            size: descriptor.size,
67        };
68
69        let buffer = Buffer::new(
70            generate_buffer_id(),
71            device,
72            descriptor.size,
73            descriptor.usage,
74            descriptor.clone(),
75            handle,
76        );
77
78        Ok(buffer)
79    }
80
81    /// Create a CPU buffer from existing data
82    pub fn from_data(data: Vec<u8>, usage: BufferUsage) -> Self {
83        let size = data.len();
84        Self {
85            data: Arc::new(RwLock::new(data)),
86            size,
87            usage,
88        }
89    }
90
91    /// Get the buffer size in bytes
92    pub fn size(&self) -> usize {
93        self.size
94    }
95
96    /// Get the buffer usage flags
97    pub fn usage(&self) -> BufferUsage {
98        self.usage
99    }
100
101    /// Read data from the buffer
102    pub fn read_bytes(&self, dst: &mut [u8], offset: usize) -> BackendResult<()> {
103        let data = self.data.read().map_err(|_| {
104            BackendError::AllocationError("Failed to acquire read lock".to_string())
105        })?;
106
107        if offset + dst.len() > data.len() {
108            return Err(BackendError::AllocationError(format!(
109                "Read bounds check failed: offset {} + size {} > buffer size {}",
110                offset,
111                dst.len(),
112                data.len()
113            )));
114        }
115
116        dst.copy_from_slice(&data[offset..offset + dst.len()]);
117        Ok(())
118    }
119
120    /// Write data to the buffer
121    pub fn write_bytes(&self, src: &[u8], offset: usize) -> BackendResult<()> {
122        let mut data = self.data.write().map_err(|_| {
123            BackendError::AllocationError("Failed to acquire write lock".to_string())
124        })?;
125
126        if offset + src.len() > data.len() {
127            return Err(BackendError::AllocationError(format!(
128                "Write bounds check failed: offset {} + size {} > buffer size {}",
129                offset,
130                src.len(),
131                data.len()
132            )));
133        }
134
135        data[offset..offset + src.len()].copy_from_slice(src);
136        Ok(())
137    }
138
139    /// Copy data from another CPU buffer
140    pub fn copy_to(
141        &self,
142        dst: &CpuBuffer,
143        src_offset: usize,
144        dst_offset: usize,
145        size: usize,
146    ) -> BackendResult<()> {
147        let src_data = self.data.read().map_err(|_| {
148            BackendError::AllocationError("Failed to acquire source read lock".to_string())
149        })?;
150
151        let mut dst_data = dst.data.write().map_err(|_| {
152            BackendError::AllocationError("Failed to acquire destination write lock".to_string())
153        })?;
154
155        if src_offset + size > src_data.len() {
156            return Err(BackendError::AllocationError(format!(
157                "Source bounds check failed: offset {} + size {} > buffer size {}",
158                src_offset,
159                size,
160                src_data.len()
161            )));
162        }
163
164        if dst_offset + size > dst_data.len() {
165            return Err(BackendError::AllocationError(format!(
166                "Destination bounds check failed: offset {} + size {} > buffer size {}",
167                dst_offset,
168                size,
169                dst_data.len()
170            )));
171        }
172
173        dst_data[dst_offset..dst_offset + size]
174            .copy_from_slice(&src_data[src_offset..src_offset + size]);
175
176        Ok(())
177    }
178
179    /// Get a reference to the underlying data (for zero-copy operations)
180    pub fn data(&self) -> Arc<RwLock<Vec<u8>>> {
181        self.data.clone()
182    }
183
184    /// Map the buffer for reading (returns a read guard)
185    pub fn map_read(&self) -> BackendResult<std::sync::RwLockReadGuard<'_, Vec<u8>>> {
186        self.data
187            .read()
188            .map_err(|_| BackendError::AllocationError("Failed to acquire read lock".to_string()))
189    }
190
191    /// Map the buffer for writing (returns a write guard)
192    pub fn map_write(&self) -> BackendResult<std::sync::RwLockWriteGuard<'_, Vec<u8>>> {
193        self.data
194            .write()
195            .map_err(|_| BackendError::AllocationError("Failed to acquire write lock".to_string()))
196    }
197}
198
// Extension trait for Buffer to work with CPU buffers
pub trait BufferCpuExt {
    /// True if this buffer is CPU-resident (legacy `Cpu` handle or a
    /// `Generic` handle wrapping a `CpuBuffer`).
    fn is_cpu(&self) -> bool;
    /// Raw pointer to the CPU bytes, if this is a CPU buffer; `None` for
    /// other backends or if the data lock cannot be acquired.
    fn as_cpu_ptr(&self) -> Option<*mut u8>;
    /// The wrapped `CpuBuffer`, if this buffer uses a `Generic` handle;
    /// legacy `Cpu` handles return `None`.
    fn as_cpu_buffer(&self) -> Option<&CpuBuffer>;
}
205
206impl BufferCpuExt for Buffer {
207    fn is_cpu(&self) -> bool {
208        match &self.handle {
209            BufferHandle::Cpu { .. } => true,
210            BufferHandle::Generic { handle, .. } => {
211                // Check if the generic handle contains a CpuBuffer
212                handle.downcast_ref::<CpuBuffer>().is_some()
213            }
214            #[cfg(feature = "cuda")]
215            BufferHandle::Cuda { .. } => false,
216            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
217            BufferHandle::Metal { .. } => false,
218            #[cfg(feature = "webgpu")]
219            BufferHandle::WebGpu { .. } => false,
220        }
221    }
222
223    fn as_cpu_ptr(&self) -> Option<*mut u8> {
224        match &self.handle {
225            BufferHandle::Cpu { ptr, .. } => Some(*ptr),
226            BufferHandle::Generic { handle, .. } => {
227                // Safely extract pointer from CpuBuffer
228                if let Some(cpu_buffer) = handle.downcast_ref::<CpuBuffer>() {
229                    // Get pointer to the underlying data safely
230                    let data_guard = cpu_buffer.data.read().ok()?;
231                    Some(data_guard.as_ptr() as *mut u8)
232                } else {
233                    None
234                }
235            }
236            #[cfg(feature = "cuda")]
237            BufferHandle::Cuda { .. } => None,
238            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
239            BufferHandle::Metal { .. } => None,
240            #[cfg(feature = "webgpu")]
241            BufferHandle::WebGpu { .. } => None,
242        }
243    }
244
245    fn as_cpu_buffer(&self) -> Option<&CpuBuffer> {
246        match &self.handle {
247            BufferHandle::Generic { handle, .. } => handle.downcast_ref::<CpuBuffer>(),
248            BufferHandle::Cpu { .. } => None, // Legacy CPU buffers don't have CpuBuffer reference
249            #[cfg(feature = "cuda")]
250            BufferHandle::Cuda { .. } => None,
251            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
252            BufferHandle::Metal { .. } => None,
253            #[cfg(feature = "webgpu")]
254            BufferHandle::WebGpu { .. } => None,
255        }
256    }
257}
258
// Unsafe operations for performance-critical code
impl CpuBuffer {
    /// Get a raw pointer to the buffer data (unsafe)
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    /// - The returned pointer is not used after the buffer is dropped
    /// - No mutable references to the buffer exist when using this pointer
    /// - The buffer is not resized while using this pointer
    ///
    /// NOTE(review): the read guard is released when this function returns,
    /// so the pointer escapes the lock — concurrent writers are NOT excluded
    /// while the caller dereferences it; the contract above must cover that.
    pub unsafe fn as_ptr(&self) -> *const u8 {
        // Briefly taking the lock ensures no writer is mid-mutation at the
        // instant the pointer is read; later validity is the caller's job.
        let data = self.data.read().expect("lock should not be poisoned");
        data.as_ptr()
    }

    /// Get a raw mutable pointer to the buffer data (unsafe)
    ///
    /// # Safety
    ///
    /// The caller must ensure:
    /// - The returned pointer is not used after the buffer is dropped
    /// - No other references to the buffer exist when using this pointer
    /// - The buffer is not resized while using this pointer
    ///
    /// NOTE(review): as with `as_ptr`, the write guard is dropped on return,
    /// so exclusivity while the pointer is in use is the caller's burden.
    pub unsafe fn as_mut_ptr(&self) -> *mut u8 {
        let mut data = self.data.write().expect("lock should not be poisoned");
        data.as_mut_ptr()
    }
}
287
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly created buffer reports the size and usage it was built with.
    #[test]
    fn test_cpu_buffer_creation() {
        let buf = CpuBuffer::new(1024, BufferUsage::STORAGE).expect("allocation should succeed");
        assert_eq!(buf.usage(), BufferUsage::STORAGE);
        assert_eq!(buf.size(), 1024);
    }

    // Bytes written at an offset are read back unchanged from that offset.
    #[test]
    fn test_cpu_buffer_read_write() {
        let buf = CpuBuffer::new(256, BufferUsage::STORAGE).unwrap();

        let payload: Vec<u8> = (1..=5).collect();
        buf.write_bytes(&payload, 10).unwrap();

        let mut echoed = vec![0u8; payload.len()];
        buf.read_bytes(&mut echoed, 10).unwrap();

        assert_eq!(echoed, payload);
    }

    // copy_to moves bytes between two distinct buffers.
    #[test]
    fn test_cpu_buffer_copy() {
        let source = CpuBuffer::new(256, BufferUsage::STORAGE).unwrap();
        let sink = CpuBuffer::new(256, BufferUsage::STORAGE).unwrap();

        let expected = vec![10u8, 20, 30, 40, 50];
        source.write_bytes(&expected, 0).unwrap();
        source.copy_to(&sink, 0, 0, expected.len()).unwrap();

        let mut actual = vec![0u8; expected.len()];
        sink.read_bytes(&mut actual, 0).unwrap();

        assert_eq!(actual, expected);
    }

    // Reads and writes that extend past the end of the buffer must fail.
    #[test]
    fn test_buffer_bounds_checking() {
        let buf = CpuBuffer::new(10, BufferUsage::STORAGE).unwrap();

        let mut scratch = [0u8; 5];
        assert!(buf.read_bytes(&mut scratch, 10).is_err(), "read past end must fail");

        let payload = [1u8, 2, 3, 4, 5];
        assert!(buf.write_bytes(&payload, 10).is_err(), "write past end must fail");
    }
}
342}