tensorlogic_scirs_backend/
device.rs1use std::fmt;
7
/// Kind of compute backend a device belongs to.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceType {
    /// The host CPU.
    Cpu,
    /// A CUDA device.
    Cuda,
    /// A Metal device.
    Metal,
    /// A Vulkan compute device.
    Vulkan,
    /// A ROCm device.
    Rocm,
}
26
27impl DeviceType {
28 pub fn is_gpu(&self) -> bool {
30 matches!(
31 self,
32 DeviceType::Cuda | DeviceType::Metal | DeviceType::Vulkan | DeviceType::Rocm
33 )
34 }
35
36 pub fn is_cpu(&self) -> bool {
38 matches!(self, DeviceType::Cpu)
39 }
40
41 pub fn name(&self) -> &'static str {
43 match self {
44 DeviceType::Cpu => "CPU",
45 DeviceType::Cuda => "CUDA",
46 DeviceType::Metal => "Metal",
47 DeviceType::Vulkan => "Vulkan",
48 DeviceType::Rocm => "ROCm",
49 }
50 }
51}
52
53impl fmt::Display for DeviceType {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 write!(f, "{}", self.name())
56 }
57}
58
59#[derive(Debug, Clone, PartialEq, Eq, Hash)]
61pub struct Device {
62 pub device_type: DeviceType,
64
65 pub index: usize,
67}
68
69impl Device {
70 pub fn cpu() -> Self {
72 Self {
73 device_type: DeviceType::Cpu,
74 index: 0,
75 }
76 }
77
78 pub fn cuda(index: usize) -> Self {
80 Self {
81 device_type: DeviceType::Cuda,
82 index,
83 }
84 }
85
86 pub fn metal() -> Self {
88 Self {
89 device_type: DeviceType::Metal,
90 index: 0,
91 }
92 }
93
94 pub fn vulkan(index: usize) -> Self {
96 Self {
97 device_type: DeviceType::Vulkan,
98 index,
99 }
100 }
101
102 pub fn rocm(index: usize) -> Self {
104 Self {
105 device_type: DeviceType::Rocm,
106 index,
107 }
108 }
109
110 pub fn is_cpu(&self) -> bool {
112 self.device_type.is_cpu()
113 }
114
115 pub fn is_gpu(&self) -> bool {
117 self.device_type.is_gpu()
118 }
119
120 pub fn device_type(&self) -> DeviceType {
122 self.device_type
123 }
124
125 pub fn index(&self) -> usize {
127 self.index
128 }
129}
130
131impl Default for Device {
132 fn default() -> Self {
133 Self::cpu()
134 }
135}
136
137impl fmt::Display for Device {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 if self.index == 0 && self.is_cpu() {
140 write!(f, "{}", self.device_type)
141 } else {
142 write!(f, "{}:{}", self.device_type, self.index)
143 }
144 }
145}
146
147#[derive(Debug, Clone)]
153pub struct SystemDeviceManager {
154 available_devices: Vec<Device>,
156
157 default_device: Device,
159}
160
161impl SystemDeviceManager {
162 pub fn new() -> Self {
167 #[cfg(test)] let available_devices = vec![Device::cpu()];
169
170 #[cfg(not(test))] let available_devices = {
172 let mut devices = vec![Device::cpu()];
173 let cuda_devices = crate::cuda_detect::detect_cuda_devices();
174 for cuda_info in cuda_devices {
175 devices.push(Device::cuda(cuda_info.index));
176 }
177 devices
178 };
179
180 Self {
181 available_devices: available_devices.clone(),
182 default_device: available_devices[0].clone(),
183 }
184 }
185
186 pub fn available_devices(&self) -> &[Device] {
188 &self.available_devices
189 }
190
191 pub fn default_device(&self) -> &Device {
193 &self.default_device
194 }
195
196 pub fn set_default_device(&mut self, device: Device) -> Result<(), DeviceError> {
198 if !self.available_devices.contains(&device) {
199 return Err(DeviceError::DeviceNotAvailable(device));
200 }
201 self.default_device = device;
202 Ok(())
203 }
204
205 pub fn is_available(&self, device: &Device) -> bool {
207 self.available_devices.contains(device)
208 }
209
210 pub fn get_device(&self, device_type: DeviceType, index: usize) -> Option<&Device> {
212 self.available_devices
213 .iter()
214 .find(|d| d.device_type == device_type && d.index == index)
215 }
216
217 pub fn count_devices(&self, device_type: DeviceType) -> usize {
219 self.available_devices
220 .iter()
221 .filter(|d| d.device_type == device_type)
222 .count()
223 }
224
225 pub fn has_gpu(&self) -> bool {
227 self.available_devices.iter().any(|d| d.is_gpu())
228 }
229}
230
231impl Default for SystemDeviceManager {
232 fn default() -> Self {
233 Self::new()
234 }
235}
236
237#[derive(Debug, Clone, thiserror::Error)]
239pub enum DeviceError {
240 #[error("Device not available: {0}")]
242 DeviceNotAvailable(Device),
243
244 #[error("Device memory allocation failed: {0}")]
246 AllocationFailed(String),
247
248 #[error("Device synchronization failed: {0}")]
250 SyncFailed(String),
251
252 #[error("Unsupported operation on device {device}: {operation}")]
254 UnsupportedOperation { device: Device, operation: String },
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_device_type_properties() {
        assert!(DeviceType::Cpu.is_cpu());
        assert!(!DeviceType::Cpu.is_gpu());

        let gpus = [
            DeviceType::Cuda,
            DeviceType::Metal,
            DeviceType::Vulkan,
            DeviceType::Rocm,
        ];
        for gpu in gpus.iter() {
            assert!(gpu.is_gpu(), "{} should be a GPU", gpu);
            assert!(!gpu.is_cpu(), "{} should not be a CPU", gpu);
        }
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "CPU");
        assert_eq!(DeviceType::Cuda.to_string(), "CUDA");
        assert_eq!(DeviceType::Metal.to_string(), "Metal");
        assert_eq!(DeviceType::Vulkan.to_string(), "Vulkan");
        assert_eq!(DeviceType::Rocm.to_string(), "ROCm");
        // Display must agree with `name()`.
        assert_eq!(DeviceType::Rocm.name(), "ROCm");
    }

    #[test]
    fn test_device_creation() {
        let cpu = Device::cpu();
        assert!(cpu.is_cpu());
        assert_eq!(cpu.index(), 0);

        let cuda = Device::cuda(1);
        assert!(cuda.is_gpu());
        assert_eq!(cuda.index(), 1);
        assert_eq!(cuda.device_type(), DeviceType::Cuda);

        assert_eq!(Device::metal().device_type(), DeviceType::Metal);
        assert_eq!(Device::metal().index(), 0);
        assert_eq!(Device::vulkan(2).device_type(), DeviceType::Vulkan);
        assert_eq!(Device::vulkan(2).index(), 2);
        assert_eq!(Device::rocm(3).device_type(), DeviceType::Rocm);
        assert_eq!(Device::rocm(3).index(), 3);
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert!(device.is_cpu());
        assert_eq!(device.index(), 0);
    }

    #[test]
    fn test_device_display() {
        assert_eq!(Device::cpu().to_string(), "CPU");
        assert_eq!(Device::cuda(0).to_string(), "CUDA:0");
        assert_eq!(Device::cuda(1).to_string(), "CUDA:1");
        assert_eq!(Device::metal().to_string(), "Metal:0");
        assert_eq!(Device::vulkan(2).to_string(), "Vulkan:2");
        assert_eq!(Device::rocm(0).to_string(), "ROCm:0");
    }

    #[test]
    fn test_device_hash_eq() {
        let set: HashSet<Device> = vec![
            Device::cpu(),
            Device::cpu(),
            Device::cuda(0),
            Device::cuda(1),
        ]
        .into_iter()
        .collect();
        assert_eq!(set.len(), 3);
    }

    #[test]
    fn test_device_manager_creation() {
        let manager = SystemDeviceManager::new();
        assert!(!manager.available_devices().is_empty());
        assert!(manager.default_device().is_cpu());
    }

    #[test]
    fn test_device_manager_queries() {
        let manager = SystemDeviceManager::new();

        assert!(manager.is_available(&Device::cpu()));
        // Test builds register only the CPU, because hardware probing is
        // compiled out under cfg(test).
        assert!(!manager.is_available(&Device::cuda(0)));
        assert_eq!(manager.count_devices(DeviceType::Cpu), 1);
        assert_eq!(manager.count_devices(DeviceType::Cuda), 0);
        assert!(!manager.has_gpu());

        assert_eq!(manager.default_device(), &Device::cpu());
    }

    #[test]
    fn test_device_manager_set_default() {
        let mut manager = SystemDeviceManager::new();
        let cpu = Device::cpu();

        assert!(manager.set_default_device(cpu.clone()).is_ok());
        assert_eq!(manager.default_device(), &cpu);

        let missing = Device::cuda(99);
        match manager.set_default_device(missing.clone()) {
            Err(DeviceError::DeviceNotAvailable(device)) => assert_eq!(device, missing),
            other => panic!("expected DeviceNotAvailable, got {:?}", other),
        }
        // A rejected request must leave the previous default in place.
        assert_eq!(manager.default_device(), &cpu);
    }

    #[test]
    fn test_device_manager_get_device() {
        let manager = SystemDeviceManager::new();

        assert_eq!(manager.get_device(DeviceType::Cpu, 0), Some(&Device::cpu()));
        assert!(manager.get_device(DeviceType::Cpu, 1).is_none());
        assert!(manager.get_device(DeviceType::Cuda, 0).is_none());
    }

    #[test]
    fn test_device_error_display() {
        assert_eq!(
            DeviceError::DeviceNotAvailable(Device::cuda(0)).to_string(),
            "Device not available: CUDA:0"
        );
        assert_eq!(
            DeviceError::AllocationFailed("out of memory".to_string()).to_string(),
            "Device memory allocation failed: out of memory"
        );
        assert_eq!(
            DeviceError::SyncFailed("timeout".to_string()).to_string(),
            "Device synchronization failed: timeout"
        );
        let err = DeviceError::UnsupportedOperation {
            device: Device::metal(),
            operation: "fft".to_string(),
        };
        assert_eq!(err.to_string(), "Unsupported operation on device Metal:0: fft");
    }
}