//! rustorch_core/storage.rs — device-tagged, reference-counted tensor storage
//! with feature-gated CPU, CUDA, wgpu, and Vulkan backends.
1use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
2use std::fmt;
3use std::sync::Arc;
4
5#[cfg(feature = "cuda")]
6use cudarc::driver::CudaSlice;
7#[cfg(feature = "vulkan_backend")]
8use vulkano::buffer::Subbuffer;
9#[cfg(feature = "wgpu_backend")]
10use wgpu;
11
/// Identifies the compute device a `Storage` lives on.
/// The `usize` payload on GPU variants is the backend-specific device index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Device {
    /// Host memory (backed by a plain `Vec<f32>`).
    Cpu,
    /// NVIDIA GPU via cudarc; value is the CUDA device id.
    Cuda(usize),
    /// Apple GPU device index. NOTE(review): no matching `StorageImpl`
    /// variant is visible in this file — confirm how Metal data is backed.
    Metal(usize),
    /// wgpu-backed GPU; value is the device index.
    Wgpu(usize),
    /// Vulkan (vulkano) GPU; value is the device index.
    Vulkan(usize),
}
20
21impl Device {
22    pub fn is_wgpu(&self) -> bool {
23        match self {
24            Device::Wgpu(_) => true,
25            _ => false,
26        }
27    }
28}
29
/// Backing memory for a `Storage`, one variant per compiled-in backend.
/// Cloning is cheap: every variant holds its payload behind an `Arc`.
#[derive(Clone)]
enum StorageImpl {
    /// Host data guarded by a `parking_lot::RwLock` for shared mutation.
    Cpu(Arc<RwLock<Vec<f32>>>),
    /// Device memory managed by cudarc.
    #[cfg(feature = "cuda")]
    Cuda(Arc<CudaSlice<f32>>),
    /// Placeholder so match arms compile when CUDA support is disabled;
    /// never constructed by the visible code.
    #[cfg(not(feature = "cuda"))]
    #[allow(dead_code)]
    CudaStub,
    /// Pooled wgpu buffer plus its logical length in f32 elements
    /// (the buffer itself only knows its byte size).
    #[cfg(feature = "wgpu_backend")]
    Wgpu(Arc<PooledBuffer>, usize),

    /// Vulkano subbuffer of f32 elements.
    #[cfg(feature = "vulkan_backend")]
    Vulkan(Arc<Subbuffer<[f32]>>),
}
44
/// A wgpu buffer that is handed back to the global memory pool on drop
/// instead of being freed outright.
#[cfg(feature = "wgpu_backend")]
pub struct PooledBuffer {
    // `Option` so `Drop` can move the buffer out exactly once via `take()`.
    buffer: Option<wgpu::Buffer>,
    // Allocation size in bytes; passed to `return_buffer` on drop.
    size: u64,
}
50
#[cfg(feature = "wgpu_backend")]
impl Drop for PooledBuffer {
    /// Returns the wrapped buffer to the global wgpu memory pool,
    /// presumably so equally-sized allocations can reuse it — see
    /// `crate::backend::wgpu::get_memory_pool` for the pool's policy.
    fn drop(&mut self) {
        // `take()` leaves `None` behind, so the buffer can be returned
        // at most once.
        if let Some(buf) = self.buffer.take() {
            crate::backend::wgpu::get_memory_pool().return_buffer(buf, self.size);
        }
    }
}
59
/// Reference-counted tensor storage tagged with the device it lives on.
/// Cloning is cheap (an `Arc` bump); all clones share the same memory.
#[derive(Clone)]
pub struct Storage {
    // Backend-specific backing memory.
    inner: StorageImpl,
    // Device tag; constructors keep it consistent with the `inner` variant.
    device: Device,
}
65
66impl Storage {
67    pub fn new(data: Vec<f32>) -> Self {
68        Self {
69            inner: StorageImpl::Cpu(Arc::new(RwLock::new(data))),
70            device: Device::Cpu,
71        }
72    }
73
74    #[cfg(feature = "cuda")]
75    pub fn new_cuda(data: CudaSlice<f32>, device_id: usize) -> Self {
76        Self {
77            inner: StorageImpl::Cuda(Arc::new(data)),
78            device: Device::Cuda(device_id),
79        }
80    }
81
82    #[cfg(feature = "wgpu_backend")]
83    pub fn new_wgpu(buffer: wgpu::Buffer, size: usize, _device_id: usize) -> Self {
84        let size_bytes = (size * std::mem::size_of::<f32>()) as u64;
85        let pooled = PooledBuffer {
86            buffer: Some(buffer),
87            size: size_bytes,
88        };
89        Self {
90            inner: StorageImpl::Wgpu(Arc::new(pooled), size),
91            device: Device::Wgpu(_device_id),
92        }
93    }
94
95    #[cfg(feature = "vulkan_backend")]
96    pub fn new_vulkan(buffer: Arc<Subbuffer<[f32]>>, device_id: usize) -> Self {
97        Self {
98            inner: StorageImpl::Vulkan(buffer),
99            device: Device::Vulkan(device_id),
100        }
101    }
102
103    #[cfg(feature = "wgpu_backend")]
104    pub fn wgpu_buffer(&self) -> Option<&wgpu::Buffer> {
105        match &self.inner {
106            StorageImpl::Wgpu(pooled, _) => pooled.buffer.as_ref(),
107            _ => None,
108        }
109    }
110
111    #[cfg(feature = "vulkan_backend")]
112    pub fn vulkan_buffer(&self) -> Option<&Subbuffer<[f32]>> {
113        match &self.inner {
114            StorageImpl::Vulkan(buffer) => Some(buffer),
115            _ => None,
116        }
117    }
118
119    pub fn from_slice(data: &[f32]) -> Self {
120        Self::new(data.to_vec())
121    }
122
123    pub fn zeros(size: usize) -> Self {
124        Self::new(vec![0.0; size])
125    }
126
127    pub fn data(&self) -> RwLockReadGuard<'_, Vec<f32>> {
128        match &self.inner {
129            StorageImpl::Cpu(data) => data.read(),
130            #[cfg(feature = "wgpu_backend")]
131            StorageImpl::Wgpu(_, _) => {
132                // Temporary workaround: panic with clear message
133                // Ideally, we should not access data() on WGPU tensor without to_cpu()
134                // But some code paths might do it implicitly.
135                // We CANNOT return a RwLockReadGuard here because we don't have the data locally locked.
136                // We must panic or refactor `data()` to return `Cow<[f32]>` or similar, but that breaks API.
137
138                println!(
139                    "CRITICAL ERROR: data() called on non-CPU storage. Device: {:?}",
140                    self.device
141                );
142                panic!("data() accessor only supported on CPU tensors. Use to_device() to move to CPU first.");
143            }
144            _ => {
145                println!(
146                    "CRITICAL ERROR: data() called on non-CPU storage. Device: {:?}",
147                    self.device
148                );
149                panic!("data() accessor only supported on CPU tensors. Use to_device() to move to CPU first.");
150            }
151        }
152    }
153
154    pub fn data_mut(&self) -> RwLockWriteGuard<'_, Vec<f32>> {
155        match &self.inner {
156            StorageImpl::Cpu(data) => data.write(),
157            _ => panic!("data_mut() accessor only supported on CPU tensors."),
158        }
159    }
160
161    pub fn as_slice(&self) -> RwLockReadGuard<'_, Vec<f32>> {
162        self.data()
163    }
164
165    pub fn len(&self) -> usize {
166        match &self.inner {
167            StorageImpl::Cpu(data) => data.read().len(),
168            #[cfg(feature = "cuda")]
169            StorageImpl::Cuda(data) => data.len(),
170            #[cfg(not(feature = "cuda"))]
171            #[allow(unused_variables)]
172            StorageImpl::CudaStub => 0,
173            #[cfg(feature = "wgpu_backend")]
174            StorageImpl::Wgpu(_, size) => *size,
175            #[cfg(feature = "vulkan_backend")]
176            StorageImpl::Vulkan(buf) => buf.len() as usize,
177        }
178    }
179
180    pub fn is_empty(&self) -> bool {
181        self.len() == 0
182    }
183
184    pub fn device(&self) -> Device {
185        self.device
186    }
187
188    pub fn to_device(&self, device: Device) -> Self {
189        if self.device == device {
190            return self.clone();
191        }
192
193        match (self.device, device) {
194            (Device::Cpu, Device::Cuda(_id)) => {
195                // Implement CPU -> CUDA transfer
196                #[cfg(feature = "cuda")]
197                {
198                    // Need a way to get CudaDevice instance.
199                    // Usually managed by a global context manager.
200                    // For now, panic or todo.
201                    todo!("Implement CPU -> CUDA transfer")
202                }
203                #[cfg(not(feature = "cuda"))]
204                panic!("CUDA feature not enabled")
205            }
206            (Device::Cuda(_), Device::Cpu) => {
207                // Implement CUDA -> CPU transfer
208                #[cfg(feature = "cuda")]
209                {
210                    // Read from GPU
211                    todo!("Implement CUDA -> CPU transfer")
212                }
213                #[cfg(not(feature = "cuda"))]
214                panic!("CUDA feature not enabled")
215            }
216            _ => todo!(
217                "Transfer between {:?} and {:?} not implemented",
218                self.device,
219                device
220            ),
221        }
222    }
223}
224
225impl fmt::Debug for Storage {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        match &self.inner {
228            StorageImpl::Cpu(data) => {
229                let guard = data.read();
230                write!(f, "Storage({:?}, size={})", self.device, guard.len())
231            }
232            #[cfg(feature = "cuda")]
233            StorageImpl::Cuda(data) => {
234                write!(f, "CudaStorage({:?}, size={})", self.device, data.len())
235            }
236            #[cfg(not(feature = "cuda"))]
237            StorageImpl::CudaStub => {
238                write!(f, "CudaStorageStub({:?})", self.device)
239            }
240            #[cfg(feature = "wgpu_backend")]
241            StorageImpl::Wgpu(_, size) => {
242                write!(f, "WgpuStorage({:?}, size={})", self.device, size)
243            }
244            #[cfg(feature = "vulkan_backend")]
245            StorageImpl::Vulkan(buf) => {
246                write!(f, "VulkanStorage({:?}, size={})", self.device, buf.len())
247            }
248        }
249    }
250}