tensorlogic_scirs_backend/
device.rs1use std::fmt;
7
/// Kind of compute backend that a [`Device`] refers to.
///
/// `Cpu` is the only non-GPU variant; `Cuda`, `Metal`, `Vulkan` and `Rocm`
/// are all reported as GPU backends by [`DeviceType::is_gpu`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceType {
    /// Host CPU; named `"CPU"`.
    Cpu,

    /// CUDA backend; named `"CUDA"`.
    Cuda,

    /// Metal backend; named `"Metal"`.
    Metal,

    /// Vulkan backend; named `"Vulkan"`.
    Vulkan,

    /// ROCm backend; named `"ROCm"`.
    Rocm,
}
26
27impl DeviceType {
28 pub fn is_gpu(&self) -> bool {
30 matches!(
31 self,
32 DeviceType::Cuda | DeviceType::Metal | DeviceType::Vulkan | DeviceType::Rocm
33 )
34 }
35
36 pub fn is_cpu(&self) -> bool {
38 matches!(self, DeviceType::Cpu)
39 }
40
41 pub fn name(&self) -> &'static str {
43 match self {
44 DeviceType::Cpu => "CPU",
45 DeviceType::Cuda => "CUDA",
46 DeviceType::Metal => "Metal",
47 DeviceType::Vulkan => "Vulkan",
48 DeviceType::Rocm => "ROCm",
49 }
50 }
51}
52
53impl fmt::Display for DeviceType {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 write!(f, "{}", self.name())
56 }
57}
58
59#[derive(Debug, Clone, PartialEq, Eq, Hash)]
61pub struct Device {
62 pub device_type: DeviceType,
64
65 pub index: usize,
67}
68
69impl Device {
70 pub fn cpu() -> Self {
72 Self {
73 device_type: DeviceType::Cpu,
74 index: 0,
75 }
76 }
77
78 pub fn cuda(index: usize) -> Self {
80 Self {
81 device_type: DeviceType::Cuda,
82 index,
83 }
84 }
85
86 pub fn metal() -> Self {
88 Self {
89 device_type: DeviceType::Metal,
90 index: 0,
91 }
92 }
93
94 pub fn vulkan(index: usize) -> Self {
96 Self {
97 device_type: DeviceType::Vulkan,
98 index,
99 }
100 }
101
102 pub fn rocm(index: usize) -> Self {
104 Self {
105 device_type: DeviceType::Rocm,
106 index,
107 }
108 }
109
110 pub fn is_cpu(&self) -> bool {
112 self.device_type.is_cpu()
113 }
114
115 pub fn is_gpu(&self) -> bool {
117 self.device_type.is_gpu()
118 }
119
120 pub fn device_type(&self) -> DeviceType {
122 self.device_type
123 }
124
125 pub fn index(&self) -> usize {
127 self.index
128 }
129}
130
131impl Default for Device {
132 fn default() -> Self {
133 Self::cpu()
134 }
135}
136
137impl fmt::Display for Device {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 if self.index == 0 && self.is_cpu() {
140 write!(f, "{}", self.device_type)
141 } else {
142 write!(f, "{}:{}", self.device_type, self.index)
143 }
144 }
145}
146
147#[derive(Debug, Clone)]
149pub struct DeviceManager {
150 available_devices: Vec<Device>,
152
153 default_device: Device,
155}
156
157impl DeviceManager {
158 pub fn new() -> Self {
163 #[cfg(test)] let available_devices = vec![Device::cpu()];
165
166 #[cfg(not(test))] let available_devices = {
168 let mut devices = vec![Device::cpu()];
169 let cuda_devices = crate::cuda_detect::detect_cuda_devices();
170 for cuda_info in cuda_devices {
171 devices.push(Device::cuda(cuda_info.index));
172 }
173 devices
174 };
175
176 Self {
177 available_devices: available_devices.clone(),
178 default_device: available_devices[0].clone(),
179 }
180 }
181
182 pub fn available_devices(&self) -> &[Device] {
184 &self.available_devices
185 }
186
187 pub fn default_device(&self) -> &Device {
189 &self.default_device
190 }
191
192 pub fn set_default_device(&mut self, device: Device) -> Result<(), DeviceError> {
194 if !self.available_devices.contains(&device) {
195 return Err(DeviceError::DeviceNotAvailable(device));
196 }
197 self.default_device = device;
198 Ok(())
199 }
200
201 pub fn is_available(&self, device: &Device) -> bool {
203 self.available_devices.contains(device)
204 }
205
206 pub fn get_device(&self, device_type: DeviceType, index: usize) -> Option<&Device> {
208 self.available_devices
209 .iter()
210 .find(|d| d.device_type == device_type && d.index == index)
211 }
212
213 pub fn count_devices(&self, device_type: DeviceType) -> usize {
215 self.available_devices
216 .iter()
217 .filter(|d| d.device_type == device_type)
218 .count()
219 }
220
221 pub fn has_gpu(&self) -> bool {
223 self.available_devices.iter().any(|d| d.is_gpu())
224 }
225}
226
227impl Default for DeviceManager {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
233#[derive(Debug, Clone, thiserror::Error)]
235pub enum DeviceError {
236 #[error("Device not available: {0}")]
238 DeviceNotAvailable(Device),
239
240 #[error("Device memory allocation failed: {0}")]
242 AllocationFailed(String),
243
244 #[error("Device synchronization failed: {0}")]
246 SyncFailed(String),
247
248 #[error("Unsupported operation on device {device}: {operation}")]
250 UnsupportedOperation { device: Device, operation: String },
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// CPU is the only non-GPU backend.
    #[test]
    fn test_device_type_properties() {
        let cpu = DeviceType::Cpu;
        assert!(cpu.is_cpu());
        assert!(!cpu.is_gpu());

        let cuda = DeviceType::Cuda;
        assert!(cuda.is_gpu());
        assert!(!cuda.is_cpu());

        for gpu in &[DeviceType::Metal, DeviceType::Vulkan, DeviceType::Rocm] {
            assert!(gpu.is_gpu());
        }
    }

    /// `Display` for a backend prints its label.
    #[test]
    fn test_device_type_display() {
        let expected = [
            (DeviceType::Cpu, "CPU"),
            (DeviceType::Cuda, "CUDA"),
            (DeviceType::Metal, "Metal"),
        ];
        for (kind, label) in &expected {
            assert_eq!(kind.to_string(), *label);
        }
    }

    /// Constructors set the backend and index that the accessors report.
    #[test]
    fn test_device_creation() {
        let cpu = Device::cpu();
        assert!(cpu.is_cpu());
        assert_eq!(cpu.index(), 0);

        let cuda = Device::cuda(1);
        assert!(cuda.is_gpu());
        assert_eq!(cuda.index(), 1);
        assert_eq!(cuda.device_type(), DeviceType::Cuda);
    }

    /// The default device is CPU 0.
    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert!(device.is_cpu());
        assert_eq!(device.index(), 0);
    }

    /// Only CPU 0 is printed without an index suffix.
    #[test]
    fn test_device_display() {
        let expected = [
            (Device::cpu(), "CPU"),
            (Device::cuda(0), "CUDA:0"),
            (Device::cuda(1), "CUDA:1"),
            (Device::metal(), "Metal:0"),
        ];
        for (device, text) in &expected {
            assert_eq!(device.to_string(), *text);
        }
    }

    /// A fresh manager lists at least one device and defaults to the CPU.
    #[test]
    fn test_device_manager_creation() {
        let manager = DeviceManager::new();
        assert!(!manager.available_devices().is_empty());
        assert!(manager.default_device().is_cpu());
    }

    /// Under `cfg(test)` the manager exposes exactly one CPU device.
    #[test]
    fn test_device_manager_queries() {
        let manager = DeviceManager::new();
        let cpu = Device::cpu();

        assert!(manager.is_available(&cpu));
        assert_eq!(manager.count_devices(DeviceType::Cpu), 1);

        assert_eq!(manager.default_device(), &cpu);
    }

    /// Listed devices can become the default; unlisted ones are rejected.
    #[test]
    fn test_device_manager_set_default() {
        let mut manager = DeviceManager::new();
        let cpu = Device::cpu();

        let accepted = manager.set_default_device(cpu.clone());
        assert!(accepted.is_ok());
        assert_eq!(manager.default_device(), &cpu);

        let rejected = manager.set_default_device(Device::cuda(99));
        assert!(rejected.is_err());
    }

    /// Lookup by backend and index finds the CPU and no CUDA device.
    #[test]
    fn test_device_manager_get_device() {
        let manager = DeviceManager::new();

        let cpu = manager.get_device(DeviceType::Cpu, 0);
        assert_eq!(cpu, Some(&Device::cpu()));

        let cuda = manager.get_device(DeviceType::Cuda, 0);
        assert!(cuda.is_none());
    }

    /// Error messages carry the text declared on each variant.
    #[test]
    fn test_device_error_display() {
        let unavailable = DeviceError::DeviceNotAvailable(Device::cuda(0));
        assert!(unavailable.to_string().contains("not available"));

        let allocation = DeviceError::AllocationFailed("out of memory".to_string());
        assert!(allocation.to_string().contains("allocation failed"));
    }
}