Skip to main content

tensorlogic_scirs_backend/
device.rs

1//! Device management for tensor computations.
2//!
3//! This module provides abstractions for managing compute devices
4//! (CPU, GPU, etc.) and tensor placement.
5
6use std::fmt;
7
/// Kind of compute backend a device belongs to.
///
/// The type is `Copy` and hashable, so it can be passed by value and used as a
/// key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU (the default device type).
    Cpu,

    /// GPU driven through CUDA.
    Cuda,

    /// GPU driven through Metal (Apple).
    Metal,

    /// Compute device driven through Vulkan.
    Vulkan,

    /// GPU driven through ROCm (AMD).
    Rocm,
}
26
27impl DeviceType {
28    /// Returns true if this is a GPU device.
29    pub fn is_gpu(&self) -> bool {
30        matches!(
31            self,
32            DeviceType::Cuda | DeviceType::Metal | DeviceType::Vulkan | DeviceType::Rocm
33        )
34    }
35
36    /// Returns true if this is a CPU device.
37    pub fn is_cpu(&self) -> bool {
38        matches!(self, DeviceType::Cpu)
39    }
40
41    /// Returns the name of this device type.
42    pub fn name(&self) -> &'static str {
43        match self {
44            DeviceType::Cpu => "CPU",
45            DeviceType::Cuda => "CUDA",
46            DeviceType::Metal => "Metal",
47            DeviceType::Vulkan => "Vulkan",
48            DeviceType::Rocm => "ROCm",
49        }
50    }
51}
52
53impl fmt::Display for DeviceType {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(f, "{}", self.name())
56    }
57}
58
59/// A specific compute device.
60#[derive(Debug, Clone, PartialEq, Eq, Hash)]
61pub struct Device {
62    /// Device type
63    pub device_type: DeviceType,
64
65    /// Device index (for multi-GPU systems)
66    pub index: usize,
67}
68
69impl Device {
70    /// Create a CPU device.
71    pub fn cpu() -> Self {
72        Self {
73            device_type: DeviceType::Cpu,
74            index: 0,
75        }
76    }
77
78    /// Create a CUDA device with the given index.
79    pub fn cuda(index: usize) -> Self {
80        Self {
81            device_type: DeviceType::Cuda,
82            index,
83        }
84    }
85
86    /// Create a Metal device.
87    pub fn metal() -> Self {
88        Self {
89            device_type: DeviceType::Metal,
90            index: 0,
91        }
92    }
93
94    /// Create a Vulkan device with the given index.
95    pub fn vulkan(index: usize) -> Self {
96        Self {
97            device_type: DeviceType::Vulkan,
98            index,
99        }
100    }
101
102    /// Create a ROCm device with the given index.
103    pub fn rocm(index: usize) -> Self {
104        Self {
105            device_type: DeviceType::Rocm,
106            index,
107        }
108    }
109
110    /// Returns true if this is a CPU device.
111    pub fn is_cpu(&self) -> bool {
112        self.device_type.is_cpu()
113    }
114
115    /// Returns true if this is a GPU device.
116    pub fn is_gpu(&self) -> bool {
117        self.device_type.is_gpu()
118    }
119
120    /// Returns the device type.
121    pub fn device_type(&self) -> DeviceType {
122        self.device_type
123    }
124
125    /// Returns the device index.
126    pub fn index(&self) -> usize {
127        self.index
128    }
129}
130
131impl Default for Device {
132    fn default() -> Self {
133        Self::cpu()
134    }
135}
136
137impl fmt::Display for Device {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        if self.index == 0 && self.is_cpu() {
140            write!(f, "{}", self.device_type)
141        } else {
142            write!(f, "{}:{}", self.device_type, self.index)
143        }
144    }
145}
146
147/// System device manager for querying available hardware devices.
148///
149/// This type enumerates and tracks all compute devices available on the current
150/// system (CPU, CUDA GPUs, etc.).  For operation-level device *selection* based
151/// on tensor shape and op kind, see [`crate::device_manager::DeviceManager`].
152#[derive(Debug, Clone)]
153pub struct SystemDeviceManager {
154    /// List of available devices
155    available_devices: Vec<Device>,
156
157    /// Default device
158    default_device: Device,
159}
160
161impl SystemDeviceManager {
162    /// Create a new system device manager.
163    ///
164    /// This queries the system for available devices, including CUDA GPUs
165    /// if available via nvidia-smi.
166    pub fn new() -> Self {
167        #[cfg(test)] // In tests, only CPU is available
168        let available_devices = vec![Device::cpu()];
169
170        #[cfg(not(test))] // In production, detect CUDA devices
171        let available_devices = {
172            let mut devices = vec![Device::cpu()];
173            let cuda_devices = crate::cuda_detect::detect_cuda_devices();
174            for cuda_info in cuda_devices {
175                devices.push(Device::cuda(cuda_info.index));
176            }
177            devices
178        };
179
180        Self {
181            available_devices: available_devices.clone(),
182            default_device: available_devices[0].clone(),
183        }
184    }
185
186    /// Get the list of available devices.
187    pub fn available_devices(&self) -> &[Device] {
188        &self.available_devices
189    }
190
191    /// Get the default device.
192    pub fn default_device(&self) -> &Device {
193        &self.default_device
194    }
195
196    /// Set the default device.
197    pub fn set_default_device(&mut self, device: Device) -> Result<(), DeviceError> {
198        if !self.available_devices.contains(&device) {
199            return Err(DeviceError::DeviceNotAvailable(device));
200        }
201        self.default_device = device;
202        Ok(())
203    }
204
205    /// Check if a device is available.
206    pub fn is_available(&self, device: &Device) -> bool {
207        self.available_devices.contains(device)
208    }
209
210    /// Get a device by type and index.
211    pub fn get_device(&self, device_type: DeviceType, index: usize) -> Option<&Device> {
212        self.available_devices
213            .iter()
214            .find(|d| d.device_type == device_type && d.index == index)
215    }
216
217    /// Count devices of a specific type.
218    pub fn count_devices(&self, device_type: DeviceType) -> usize {
219        self.available_devices
220            .iter()
221            .filter(|d| d.device_type == device_type)
222            .count()
223    }
224
225    /// Check if any GPU devices are available.
226    pub fn has_gpu(&self) -> bool {
227        self.available_devices.iter().any(|d| d.is_gpu())
228    }
229}
230
231impl Default for SystemDeviceManager {
232    fn default() -> Self {
233        Self::new()
234    }
235}
236
237/// Device-related errors.
238#[derive(Debug, Clone, thiserror::Error)]
239pub enum DeviceError {
240    /// Device is not available
241    #[error("Device not available: {0}")]
242    DeviceNotAvailable(Device),
243
244    /// Device memory allocation failed
245    #[error("Device memory allocation failed: {0}")]
246    AllocationFailed(String),
247
248    /// Device synchronization failed
249    #[error("Device synchronization failed: {0}")]
250    SyncFailed(String),
251
252    /// Unsupported device operation
253    #[error("Unsupported operation on device {device}: {operation}")]
254    UnsupportedOperation { device: Device, operation: String },
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU type is CPU-only; every other type counts as a GPU.
    #[test]
    fn test_device_type_properties() {
        let cpu = DeviceType::Cpu;
        assert!(cpu.is_cpu());
        assert!(!cpu.is_gpu());

        let cuda = DeviceType::Cuda;
        assert!(cuda.is_gpu());
        assert!(!cuda.is_cpu());

        for gpu in [DeviceType::Metal, DeviceType::Vulkan, DeviceType::Rocm].iter() {
            assert!(gpu.is_gpu());
        }
    }

    /// `Display` for a device type prints its label.
    #[test]
    fn test_device_type_display() {
        let expectations = [
            (DeviceType::Cpu, "CPU"),
            (DeviceType::Cuda, "CUDA"),
            (DeviceType::Metal, "Metal"),
        ];
        for (device_type, label) in expectations.iter() {
            assert_eq!(device_type.to_string(), *label);
        }
    }

    /// Constructors set the expected type and index.
    #[test]
    fn test_device_creation() {
        let host = Device::cpu();
        assert!(host.is_cpu());
        assert_eq!(host.index(), 0);

        let gpu = Device::cuda(1);
        assert!(gpu.is_gpu());
        assert_eq!(gpu.index(), 1);
        assert_eq!(gpu.device_type(), DeviceType::Cuda);
    }

    /// The default device is the CPU at index 0.
    #[test]
    fn test_device_default() {
        let fallback = Device::default();
        assert!(fallback.is_cpu());
        assert_eq!(fallback.index(), 0);
    }

    /// Only the CPU at index 0 is displayed without an index suffix.
    #[test]
    fn test_device_display() {
        let expectations = [
            (Device::cpu(), "CPU"),
            (Device::cuda(0), "CUDA:0"),
            (Device::cuda(1), "CUDA:1"),
            (Device::metal(), "Metal:0"),
        ];
        for (device, rendered) in expectations.iter() {
            assert_eq!(device.to_string(), *rendered);
        }
    }

    /// A fresh manager has at least one device and defaults to the CPU.
    #[test]
    fn test_device_manager_creation() {
        let manager = SystemDeviceManager::new();
        let devices = manager.available_devices();
        assert!(!devices.is_empty());
        assert!(manager.default_device().is_cpu());
    }

    /// The CPU is always registered exactly once and is the default.
    #[test]
    fn test_device_manager_queries() {
        let manager = SystemDeviceManager::new();
        let host = Device::cpu();

        // CPU should always be available
        assert!(manager.is_available(&host));
        assert_eq!(manager.count_devices(DeviceType::Cpu), 1);

        // Check default device
        assert_eq!(manager.default_device(), &host);
    }

    /// Changing the default succeeds only for registered devices.
    #[test]
    fn test_device_manager_set_default() {
        let mut manager = SystemDeviceManager::new();
        let host = Device::cpu();

        // Setting to an available device should succeed
        let accepted = manager.set_default_device(host.clone());
        assert!(accepted.is_ok());
        assert_eq!(manager.default_device(), &host);

        // Setting to an unavailable device should fail
        let rejected = manager.set_default_device(Device::cuda(99));
        assert!(rejected.is_err());
    }

    /// Lookup by type and index finds the CPU and nothing else in test builds.
    #[test]
    fn test_device_manager_get_device() {
        let manager = SystemDeviceManager::new();

        // Should find CPU
        let found = manager.get_device(DeviceType::Cpu, 0);
        assert!(found.is_some());
        assert_eq!(*found.expect("cpu device expected"), Device::cpu());

        // Should not find non-existent devices
        let missing = manager.get_device(DeviceType::Cuda, 0);
        assert!(missing.is_none());
    }

    /// Error messages mention what went wrong.
    #[test]
    fn test_device_error_display() {
        let unavailable = DeviceError::DeviceNotAvailable(Device::cuda(0)).to_string();
        assert!(unavailable.contains("not available"));

        let allocation = DeviceError::AllocationFailed("out of memory".to_string()).to_string();
        assert!(allocation.contains("allocation failed"));
    }
}