1use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
2use std::fmt;
3use std::sync::Arc;
4
5#[cfg(feature = "cuda")]
6use cudarc::driver::CudaSlice;
7#[cfg(feature = "vulkan_backend")]
8use vulkano::buffer::Subbuffer;
9#[cfg(feature = "wgpu_backend")]
10use wgpu;
11
/// A compute device that a [`Storage`] can live on.
///
/// The `usize` payload of the accelerator variants is the backend-specific device id
/// (e.g. the `device_id` passed to `Storage::new_cuda`).
///
/// `Hash` is derived alongside `Eq` so devices can key per-device maps and sets
/// (caches, memory pools, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    /// Host memory.
    Cpu,
    /// A CUDA device, identified by its id.
    Cuda(usize),
    /// A Metal device, identified by its id.
    Metal(usize),
    /// A wgpu device, identified by its id.
    Wgpu(usize),
    /// A Vulkan device, identified by its id.
    Vulkan(usize),
}
20
21impl Device {
22 pub fn is_wgpu(&self) -> bool {
23 match self {
24 Device::Wgpu(_) => true,
25 _ => false,
26 }
27 }
28}
29
30#[derive(Clone)]
31enum StorageImpl {
32 Cpu(Arc<RwLock<Vec<f32>>>),
33 #[cfg(feature = "cuda")]
34 Cuda(Arc<CudaSlice<f32>>),
35 #[cfg(not(feature = "cuda"))]
36 #[allow(dead_code)]
37 CudaStub,
38 #[cfg(feature = "wgpu_backend")]
39 Wgpu(Arc<PooledBuffer>, usize),
40
41 #[cfg(feature = "vulkan_backend")]
42 Vulkan(Arc<Subbuffer<[f32]>>),
43}
44
/// A wgpu buffer that is handed back to the backend memory pool when dropped,
/// instead of being released directly.
#[cfg(feature = "wgpu_backend")]
pub struct PooledBuffer {
    /// The wrapped buffer. `Storage::new_wgpu` stores `Some`; `Drop` takes it out
    /// so it can be moved into the pool.
    buffer: Option<wgpu::Buffer>,
    /// Size of `buffer` in bytes, reported back to the pool on drop.
    size: u64,
}
50
#[cfg(feature = "wgpu_backend")]
impl Drop for PooledBuffer {
    /// Returns the wrapped buffer, with its byte size, to the wgpu memory pool.
    ///
    /// Does nothing if the buffer has already been taken out.
    fn drop(&mut self) {
        let byte_size = self.size;
        if let Some(released) = self.buffer.take() {
            crate::backend::wgpu::get_memory_pool().return_buffer(released, byte_size);
        }
    }
}
59
60#[derive(Clone)]
61pub struct Storage {
62 inner: StorageImpl,
63 device: Device,
64}
65
66impl Storage {
67 pub fn new(data: Vec<f32>) -> Self {
68 Self {
69 inner: StorageImpl::Cpu(Arc::new(RwLock::new(data))),
70 device: Device::Cpu,
71 }
72 }
73
74 #[cfg(feature = "cuda")]
75 pub fn new_cuda(data: CudaSlice<f32>, device_id: usize) -> Self {
76 Self {
77 inner: StorageImpl::Cuda(Arc::new(data)),
78 device: Device::Cuda(device_id),
79 }
80 }
81
82 #[cfg(feature = "wgpu_backend")]
83 pub fn new_wgpu(buffer: wgpu::Buffer, size: usize, _device_id: usize) -> Self {
84 let size_bytes = (size * std::mem::size_of::<f32>()) as u64;
85 let pooled = PooledBuffer {
86 buffer: Some(buffer),
87 size: size_bytes,
88 };
89 Self {
90 inner: StorageImpl::Wgpu(Arc::new(pooled), size),
91 device: Device::Wgpu(_device_id),
92 }
93 }
94
95 #[cfg(feature = "vulkan_backend")]
96 pub fn new_vulkan(buffer: Arc<Subbuffer<[f32]>>, device_id: usize) -> Self {
97 Self {
98 inner: StorageImpl::Vulkan(buffer),
99 device: Device::Vulkan(device_id),
100 }
101 }
102
103 #[cfg(feature = "wgpu_backend")]
104 pub fn wgpu_buffer(&self) -> Option<&wgpu::Buffer> {
105 match &self.inner {
106 StorageImpl::Wgpu(pooled, _) => pooled.buffer.as_ref(),
107 _ => None,
108 }
109 }
110
111 #[cfg(feature = "vulkan_backend")]
112 pub fn vulkan_buffer(&self) -> Option<&Subbuffer<[f32]>> {
113 match &self.inner {
114 StorageImpl::Vulkan(buffer) => Some(buffer),
115 _ => None,
116 }
117 }
118
119 pub fn from_slice(data: &[f32]) -> Self {
120 Self::new(data.to_vec())
121 }
122
123 pub fn zeros(size: usize) -> Self {
124 Self::new(vec![0.0; size])
125 }
126
127 pub fn data(&self) -> RwLockReadGuard<'_, Vec<f32>> {
128 match &self.inner {
129 StorageImpl::Cpu(data) => data.read(),
130 #[cfg(feature = "wgpu_backend")]
131 StorageImpl::Wgpu(_, _) => {
132 println!(
139 "CRITICAL ERROR: data() called on non-CPU storage. Device: {:?}",
140 self.device
141 );
142 panic!("data() accessor only supported on CPU tensors. Use to_device() to move to CPU first.");
143 }
144 _ => {
145 println!(
146 "CRITICAL ERROR: data() called on non-CPU storage. Device: {:?}",
147 self.device
148 );
149 panic!("data() accessor only supported on CPU tensors. Use to_device() to move to CPU first.");
150 }
151 }
152 }
153
154 pub fn data_mut(&self) -> RwLockWriteGuard<'_, Vec<f32>> {
155 match &self.inner {
156 StorageImpl::Cpu(data) => data.write(),
157 _ => panic!("data_mut() accessor only supported on CPU tensors."),
158 }
159 }
160
161 pub fn as_slice(&self) -> RwLockReadGuard<'_, Vec<f32>> {
162 self.data()
163 }
164
165 pub fn len(&self) -> usize {
166 match &self.inner {
167 StorageImpl::Cpu(data) => data.read().len(),
168 #[cfg(feature = "cuda")]
169 StorageImpl::Cuda(data) => data.len(),
170 #[cfg(not(feature = "cuda"))]
171 #[allow(unused_variables)]
172 StorageImpl::CudaStub => 0,
173 #[cfg(feature = "wgpu_backend")]
174 StorageImpl::Wgpu(_, size) => *size,
175 #[cfg(feature = "vulkan_backend")]
176 StorageImpl::Vulkan(buf) => buf.len() as usize,
177 }
178 }
179
180 pub fn is_empty(&self) -> bool {
181 self.len() == 0
182 }
183
184 pub fn device(&self) -> Device {
185 self.device
186 }
187
188 pub fn to_device(&self, device: Device) -> Self {
189 if self.device == device {
190 return self.clone();
191 }
192
193 match (self.device, device) {
194 (Device::Cpu, Device::Cuda(_id)) => {
195 #[cfg(feature = "cuda")]
197 {
198 todo!("Implement CPU -> CUDA transfer")
202 }
203 #[cfg(not(feature = "cuda"))]
204 panic!("CUDA feature not enabled")
205 }
206 (Device::Cuda(_), Device::Cpu) => {
207 #[cfg(feature = "cuda")]
209 {
210 todo!("Implement CUDA -> CPU transfer")
212 }
213 #[cfg(not(feature = "cuda"))]
214 panic!("CUDA feature not enabled")
215 }
216 _ => todo!(
217 "Transfer between {:?} and {:?} not implemented",
218 self.device,
219 device
220 ),
221 }
222 }
223}
224
225impl fmt::Debug for Storage {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 match &self.inner {
228 StorageImpl::Cpu(data) => {
229 let guard = data.read();
230 write!(f, "Storage({:?}, size={})", self.device, guard.len())
231 }
232 #[cfg(feature = "cuda")]
233 StorageImpl::Cuda(data) => {
234 write!(f, "CudaStorage({:?}, size={})", self.device, data.len())
235 }
236 #[cfg(not(feature = "cuda"))]
237 StorageImpl::CudaStub => {
238 write!(f, "CudaStorageStub({:?})", self.device)
239 }
240 #[cfg(feature = "wgpu_backend")]
241 StorageImpl::Wgpu(_, size) => {
242 write!(f, "WgpuStorage({:?}, size={})", self.device, size)
243 }
244 #[cfg(feature = "vulkan_backend")]
245 StorageImpl::Vulkan(buf) => {
246 write!(f, "VulkanStorage({:?}, size={})", self.device, buf.len())
247 }
248 }
249 }
250}