1use crate::buffer::{generate_buffer_id, BufferHandle};
4use crate::error::{BackendError, BackendResult};
5use crate::{Buffer, BufferDescriptor, BufferUsage, Device};
6
7#[cfg(feature = "std")]
8use std::sync::{Arc, RwLock};
9
10#[cfg(not(feature = "std"))]
11use alloc::{sync::Arc, vec::Vec};
12
13#[cfg(not(feature = "std"))]
14use spin::RwLock;
15
/// Heap-backed buffer for the CPU backend.
///
/// Contents live in an `Arc<RwLock<Vec<u8>>>`, so `Clone` is shallow:
/// clones share — and synchronize access to — the same allocation.
#[derive(Debug, Clone)]
pub struct CpuBuffer {
    /// Shared, lock-protected byte storage; clones alias the same data.
    data: Arc<RwLock<Vec<u8>>>,
    /// Length in bytes, fixed at construction (the Vec is never resized).
    size: usize,
    /// Usage flags recorded at creation; not enforced by this type.
    usage: BufferUsage,
}
23
24impl CpuBuffer {
25 pub fn new(size: usize, usage: BufferUsage) -> BackendResult<Self> {
27 if size > isize::MAX as usize {
29 return Err(torsh_core::error::TorshError::BackendError(format!(
30 "Buffer size {} is too large (exceeds maximum allowed size)",
31 size
32 )));
33 }
34
35 let data = match size {
36 0 => Vec::new(), size => {
38 match std::panic::catch_unwind(|| vec![0u8; size]) {
40 Ok(vec) => vec,
41 Err(_) => {
42 return Err(torsh_core::error::TorshError::BackendError(format!(
43 "Failed to allocate {} bytes for buffer",
44 size
45 )));
46 }
47 }
48 }
49 };
50
51 Ok(Self {
52 data: Arc::new(RwLock::new(data)),
53 size,
54 usage,
55 })
56 }
57
58 pub fn new_buffer(device: Device, descriptor: &BufferDescriptor) -> BackendResult<Buffer> {
60 let cpu_buffer = Self::new(descriptor.size, descriptor.usage)?;
61
62 let handle = BufferHandle::Generic {
65 handle: Box::new(cpu_buffer),
66 size: descriptor.size,
67 };
68
69 let buffer = Buffer::new(
70 generate_buffer_id(),
71 device,
72 descriptor.size,
73 descriptor.usage,
74 descriptor.clone(),
75 handle,
76 );
77
78 Ok(buffer)
79 }
80
81 pub fn from_data(data: Vec<u8>, usage: BufferUsage) -> Self {
83 let size = data.len();
84 Self {
85 data: Arc::new(RwLock::new(data)),
86 size,
87 usage,
88 }
89 }
90
91 pub fn size(&self) -> usize {
93 self.size
94 }
95
96 pub fn usage(&self) -> BufferUsage {
98 self.usage
99 }
100
101 pub fn read_bytes(&self, dst: &mut [u8], offset: usize) -> BackendResult<()> {
103 let data = self.data.read().map_err(|_| {
104 BackendError::AllocationError("Failed to acquire read lock".to_string())
105 })?;
106
107 if offset + dst.len() > data.len() {
108 return Err(BackendError::AllocationError(format!(
109 "Read bounds check failed: offset {} + size {} > buffer size {}",
110 offset,
111 dst.len(),
112 data.len()
113 )));
114 }
115
116 dst.copy_from_slice(&data[offset..offset + dst.len()]);
117 Ok(())
118 }
119
120 pub fn write_bytes(&self, src: &[u8], offset: usize) -> BackendResult<()> {
122 let mut data = self.data.write().map_err(|_| {
123 BackendError::AllocationError("Failed to acquire write lock".to_string())
124 })?;
125
126 if offset + src.len() > data.len() {
127 return Err(BackendError::AllocationError(format!(
128 "Write bounds check failed: offset {} + size {} > buffer size {}",
129 offset,
130 src.len(),
131 data.len()
132 )));
133 }
134
135 data[offset..offset + src.len()].copy_from_slice(src);
136 Ok(())
137 }
138
139 pub fn copy_to(
141 &self,
142 dst: &CpuBuffer,
143 src_offset: usize,
144 dst_offset: usize,
145 size: usize,
146 ) -> BackendResult<()> {
147 let src_data = self.data.read().map_err(|_| {
148 BackendError::AllocationError("Failed to acquire source read lock".to_string())
149 })?;
150
151 let mut dst_data = dst.data.write().map_err(|_| {
152 BackendError::AllocationError("Failed to acquire destination write lock".to_string())
153 })?;
154
155 if src_offset + size > src_data.len() {
156 return Err(BackendError::AllocationError(format!(
157 "Source bounds check failed: offset {} + size {} > buffer size {}",
158 src_offset,
159 size,
160 src_data.len()
161 )));
162 }
163
164 if dst_offset + size > dst_data.len() {
165 return Err(BackendError::AllocationError(format!(
166 "Destination bounds check failed: offset {} + size {} > buffer size {}",
167 dst_offset,
168 size,
169 dst_data.len()
170 )));
171 }
172
173 dst_data[dst_offset..dst_offset + size]
174 .copy_from_slice(&src_data[src_offset..src_offset + size]);
175
176 Ok(())
177 }
178
179 pub fn data(&self) -> Arc<RwLock<Vec<u8>>> {
181 self.data.clone()
182 }
183
184 pub fn map_read(&self) -> BackendResult<std::sync::RwLockReadGuard<'_, Vec<u8>>> {
186 self.data
187 .read()
188 .map_err(|_| BackendError::AllocationError("Failed to acquire read lock".to_string()))
189 }
190
191 pub fn map_write(&self) -> BackendResult<std::sync::RwLockWriteGuard<'_, Vec<u8>>> {
193 self.data
194 .write()
195 .map_err(|_| BackendError::AllocationError("Failed to acquire write lock".to_string()))
196 }
197}
198
/// Extension methods for inspecting whether a [`Buffer`] is CPU-backed.
pub trait BufferCpuExt {
    /// Returns `true` if the buffer's handle refers to CPU memory.
    fn is_cpu(&self) -> bool;
    /// Returns a raw pointer to the underlying bytes for CPU-backed
    /// buffers, or `None` for other backends.
    fn as_cpu_ptr(&self) -> Option<*mut u8>;
    /// Returns the backing [`CpuBuffer`] when the handle stores one.
    fn as_cpu_buffer(&self) -> Option<&CpuBuffer>;
}
205
impl BufferCpuExt for Buffer {
    /// A buffer is CPU-backed when it uses the dedicated `Cpu` handle or
    /// a `Generic` handle whose payload downcasts to [`CpuBuffer`]; all
    /// feature-gated GPU handles report `false`.
    fn is_cpu(&self) -> bool {
        match &self.handle {
            BufferHandle::Cpu { .. } => true,
            BufferHandle::Generic { handle, .. } => {
                handle.downcast_ref::<CpuBuffer>().is_some()
            }
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { .. } => false,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { .. } => false,
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { .. } => false,
        }
    }

    /// Raw pointer to the buffer's bytes for CPU-backed handles.
    fn as_cpu_ptr(&self) -> Option<*mut u8> {
        match &self.handle {
            BufferHandle::Cpu { ptr, .. } => Some(*ptr),
            BufferHandle::Generic { handle, .. } => {
                if let Some(cpu_buffer) = handle.downcast_ref::<CpuBuffer>() {
                    // NOTE(review): the read guard is dropped at the end
                    // of this scope, so the returned pointer is no longer
                    // protected by the lock; it is also a `*mut` cast from
                    // a shared `as_ptr`, so writing through it while other
                    // readers exist would be UB — confirm callers only
                    // read through this pointer, or switch to the write
                    // lock / `as_mut_ptr`.
                    let data_guard = cpu_buffer.data.read().ok()?;
                    Some(data_guard.as_ptr() as *mut u8)
                } else {
                    None
                }
            }
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { .. } => None,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { .. } => None,
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { .. } => None,
        }
    }

    /// Borrows the backing [`CpuBuffer`] from a `Generic` handle.
    ///
    /// NOTE(review): a `Cpu { .. }` handle deliberately returns `None`
    /// here (it stores a raw pointer, not a `CpuBuffer`) — confirm that
    /// asymmetry with `is_cpu` is intended.
    fn as_cpu_buffer(&self) -> Option<&CpuBuffer> {
        match &self.handle {
            BufferHandle::Generic { handle, .. } => handle.downcast_ref::<CpuBuffer>(),
            BufferHandle::Cpu { .. } => None,
            #[cfg(feature = "cuda")]
            BufferHandle::Cuda { .. } => None,
            #[cfg(all(feature = "metal", target_os = "macos", target_arch = "aarch64"))]
            BufferHandle::Metal { .. } => None,
            #[cfg(feature = "webgpu")]
            BufferHandle::WebGpu { .. } => None,
        }
    }
}
258
impl CpuBuffer {
    /// Raw const pointer to the first byte of the buffer's storage.
    ///
    /// # Safety
    /// The internal read lock is released before this function returns,
    /// so the pointer is not protected against concurrent writers. The
    /// caller must guarantee that (a) no writer mutates or reallocates
    /// the storage while the pointer is in use and (b) the pointer is
    /// not dereferenced after every `Arc` clone of this buffer has been
    /// dropped.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub unsafe fn as_ptr(&self) -> *const u8 {
        let data = self.data.read().expect("lock should not be poisoned");
        data.as_ptr()
    }

    /// Raw mutable pointer to the first byte of the buffer's storage.
    ///
    /// # Safety
    /// Same constraints as [`Self::as_ptr`]; in addition, the caller
    /// must ensure exclusive access for the duration of any writes,
    /// because the write lock is released before this returns.
    ///
    /// # Panics
    /// Panics if the lock is poisoned.
    pub unsafe fn as_mut_ptr(&self) -> *mut u8 {
        let mut data = self.data.write().expect("lock should not be poisoned");
        data.as_mut_ptr()
    }
}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction records both the requested size and the usage flags.
    #[test]
    fn test_cpu_buffer_creation() {
        let buf = CpuBuffer::new(1024, BufferUsage::STORAGE).unwrap();
        assert_eq!(buf.size(), 1024);
        assert_eq!(buf.usage(), BufferUsage::STORAGE);
    }

    /// Bytes written at an offset are read back unchanged.
    #[test]
    fn test_cpu_buffer_read_write() {
        let buf = CpuBuffer::new(256, BufferUsage::STORAGE).unwrap();

        let payload = [1u8, 2, 3, 4, 5];
        buf.write_bytes(&payload, 10).unwrap();

        let mut echo = [0u8; 5];
        buf.read_bytes(&mut echo, 10).unwrap();

        assert_eq!(echo, payload);
    }

    /// `copy_to` transfers a span between two distinct buffers.
    #[test]
    fn test_cpu_buffer_copy() {
        let src = CpuBuffer::new(256, BufferUsage::STORAGE).unwrap();
        let dst = CpuBuffer::new(256, BufferUsage::STORAGE).unwrap();

        let sample = [10u8, 20, 30, 40, 50];
        src.write_bytes(&sample, 0).unwrap();
        src.copy_to(&dst, 0, 0, sample.len()).unwrap();

        let mut echo = [0u8; 5];
        dst.read_bytes(&mut echo, 0).unwrap();

        assert_eq!(echo, sample);
    }

    /// Out-of-range reads and writes are rejected with an error rather
    /// than panicking.
    #[test]
    fn test_buffer_bounds_checking() {
        let buf = CpuBuffer::new(10, BufferUsage::STORAGE).unwrap();

        let mut sink = [0u8; 5];
        assert!(buf.read_bytes(&mut sink, 10).is_err());

        let payload = [1u8, 2, 3, 4, 5];
        assert!(buf.write_bytes(&payload, 10).is_err());
    }
}