Skip to main content

tensorlogic_scirs_backend/
device.rs

1//! Device management for tensor computations.
2//!
3//! This module provides abstractions for managing compute devices
4//! (CPU, GPU, etc.) and tensor placement.
5
6use std::fmt;
7
/// Kind of hardware backend a tensor computation can run on.
///
/// `Cpu` is the only host-side variant; every other variant is classified as
/// a GPU-class accelerator by [`DeviceType::is_gpu`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU (the backend used by `Device::default()`).
    Cpu,

    /// NVIDIA CUDA GPU.
    Cuda,

    /// Apple Metal GPU.
    Metal,

    /// Vulkan compute device.
    Vulkan,

    /// AMD ROCm GPU.
    Rocm,
}

impl DeviceType {
    /// Reports whether this backend is a GPU-class accelerator.
    ///
    /// Every variant except [`DeviceType::Cpu`] counts as a GPU.
    pub fn is_gpu(&self) -> bool {
        match *self {
            Self::Cpu => false,
            Self::Cuda | Self::Metal | Self::Vulkan | Self::Rocm => true,
        }
    }

    /// Reports whether this backend is the host CPU.
    pub fn is_cpu(&self) -> bool {
        *self == Self::Cpu
    }

    /// Human-readable backend name; this is exactly what `Display` prints.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Cpu => "CPU",
            Self::Cuda => "CUDA",
            Self::Metal => "Metal",
            Self::Vulkan => "Vulkan",
            Self::Rocm => "ROCm",
        }
    }
}

impl fmt::Display for DeviceType {
    /// Writes [`DeviceType::name`] verbatim (width/alignment flags are not applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
58
59/// A specific compute device.
60#[derive(Debug, Clone, PartialEq, Eq, Hash)]
61pub struct Device {
62    /// Device type
63    pub device_type: DeviceType,
64
65    /// Device index (for multi-GPU systems)
66    pub index: usize,
67}
68
69impl Device {
70    /// Create a CPU device.
71    pub fn cpu() -> Self {
72        Self {
73            device_type: DeviceType::Cpu,
74            index: 0,
75        }
76    }
77
78    /// Create a CUDA device with the given index.
79    pub fn cuda(index: usize) -> Self {
80        Self {
81            device_type: DeviceType::Cuda,
82            index,
83        }
84    }
85
86    /// Create a Metal device.
87    pub fn metal() -> Self {
88        Self {
89            device_type: DeviceType::Metal,
90            index: 0,
91        }
92    }
93
94    /// Create a Vulkan device with the given index.
95    pub fn vulkan(index: usize) -> Self {
96        Self {
97            device_type: DeviceType::Vulkan,
98            index,
99        }
100    }
101
102    /// Create a ROCm device with the given index.
103    pub fn rocm(index: usize) -> Self {
104        Self {
105            device_type: DeviceType::Rocm,
106            index,
107        }
108    }
109
110    /// Returns true if this is a CPU device.
111    pub fn is_cpu(&self) -> bool {
112        self.device_type.is_cpu()
113    }
114
115    /// Returns true if this is a GPU device.
116    pub fn is_gpu(&self) -> bool {
117        self.device_type.is_gpu()
118    }
119
120    /// Returns the device type.
121    pub fn device_type(&self) -> DeviceType {
122        self.device_type
123    }
124
125    /// Returns the device index.
126    pub fn index(&self) -> usize {
127        self.index
128    }
129}
130
131impl Default for Device {
132    fn default() -> Self {
133        Self::cpu()
134    }
135}
136
137impl fmt::Display for Device {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        if self.index == 0 && self.is_cpu() {
140            write!(f, "{}", self.device_type)
141        } else {
142            write!(f, "{}:{}", self.device_type, self.index)
143        }
144    }
145}
146
147/// Device manager for querying available devices.
148#[derive(Debug, Clone)]
149pub struct DeviceManager {
150    /// List of available devices
151    available_devices: Vec<Device>,
152
153    /// Default device
154    default_device: Device,
155}
156
157impl DeviceManager {
158    /// Create a new device manager.
159    ///
160    /// This queries the system for available devices, including CUDA GPUs
161    /// if available via nvidia-smi.
162    pub fn new() -> Self {
163        #[cfg(test)] // In tests, only CPU is available
164        let available_devices = vec![Device::cpu()];
165
166        #[cfg(not(test))] // In production, detect CUDA devices
167        let available_devices = {
168            let mut devices = vec![Device::cpu()];
169            let cuda_devices = crate::cuda_detect::detect_cuda_devices();
170            for cuda_info in cuda_devices {
171                devices.push(Device::cuda(cuda_info.index));
172            }
173            devices
174        };
175
176        Self {
177            available_devices: available_devices.clone(),
178            default_device: available_devices[0].clone(),
179        }
180    }
181
182    /// Get the list of available devices.
183    pub fn available_devices(&self) -> &[Device] {
184        &self.available_devices
185    }
186
187    /// Get the default device.
188    pub fn default_device(&self) -> &Device {
189        &self.default_device
190    }
191
192    /// Set the default device.
193    pub fn set_default_device(&mut self, device: Device) -> Result<(), DeviceError> {
194        if !self.available_devices.contains(&device) {
195            return Err(DeviceError::DeviceNotAvailable(device));
196        }
197        self.default_device = device;
198        Ok(())
199    }
200
201    /// Check if a device is available.
202    pub fn is_available(&self, device: &Device) -> bool {
203        self.available_devices.contains(device)
204    }
205
206    /// Get a device by type and index.
207    pub fn get_device(&self, device_type: DeviceType, index: usize) -> Option<&Device> {
208        self.available_devices
209            .iter()
210            .find(|d| d.device_type == device_type && d.index == index)
211    }
212
213    /// Count devices of a specific type.
214    pub fn count_devices(&self, device_type: DeviceType) -> usize {
215        self.available_devices
216            .iter()
217            .filter(|d| d.device_type == device_type)
218            .count()
219    }
220
221    /// Check if any GPU devices are available.
222    pub fn has_gpu(&self) -> bool {
223        self.available_devices.iter().any(|d| d.is_gpu())
224    }
225}
226
227impl Default for DeviceManager {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233/// Device-related errors.
234#[derive(Debug, Clone, thiserror::Error)]
235pub enum DeviceError {
236    /// Device is not available
237    #[error("Device not available: {0}")]
238    DeviceNotAvailable(Device),
239
240    /// Device memory allocation failed
241    #[error("Device memory allocation failed: {0}")]
242    AllocationFailed(String),
243
244    /// Device synchronization failed
245    #[error("Device synchronization failed: {0}")]
246    SyncFailed(String),
247
248    /// Unsupported device operation
249    #[error("Unsupported operation on device {device}: {operation}")]
250    UnsupportedOperation { device: Device, operation: String },
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `DeviceType` variant, used to sweep per-variant assertions.
    const ALL_TYPES: [DeviceType; 5] = [
        DeviceType::Cpu,
        DeviceType::Cuda,
        DeviceType::Metal,
        DeviceType::Vulkan,
        DeviceType::Rocm,
    ];

    #[test]
    fn test_device_type_properties() {
        assert!(DeviceType::Cpu.is_cpu());
        assert!(!DeviceType::Cpu.is_gpu());

        // Every non-CPU backend is a GPU and not a CPU.
        for ty in ALL_TYPES.iter().copied().filter(|t| *t != DeviceType::Cpu) {
            assert!(ty.is_gpu(), "{} should be a GPU", ty);
            assert!(!ty.is_cpu(), "{} should not be a CPU", ty);
        }
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "CPU");
        assert_eq!(DeviceType::Cuda.to_string(), "CUDA");
        assert_eq!(DeviceType::Metal.to_string(), "Metal");
        assert_eq!(DeviceType::Vulkan.to_string(), "Vulkan");
        assert_eq!(DeviceType::Rocm.to_string(), "ROCm");
    }

    #[test]
    fn test_device_type_name_matches_display() {
        for ty in ALL_TYPES {
            assert_eq!(ty.name(), ty.to_string());
        }
    }

    #[test]
    fn test_device_creation() {
        let cpu = Device::cpu();
        assert!(cpu.is_cpu());
        assert_eq!(cpu.index(), 0);

        let cuda = Device::cuda(1);
        assert!(cuda.is_gpu());
        assert_eq!(cuda.index(), 1);
        assert_eq!(cuda.device_type(), DeviceType::Cuda);

        let metal = Device::metal();
        assert!(metal.is_gpu());
        assert_eq!(metal.device_type(), DeviceType::Metal);
        assert_eq!(metal.index(), 0);

        let vulkan = Device::vulkan(2);
        assert_eq!(vulkan.device_type(), DeviceType::Vulkan);
        assert_eq!(vulkan.index(), 2);

        let rocm = Device::rocm(3);
        assert_eq!(rocm.device_type(), DeviceType::Rocm);
        assert_eq!(rocm.index(), 3);
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert!(device.is_cpu());
        assert_eq!(device.index(), 0);
    }

    #[test]
    fn test_device_display() {
        assert_eq!(Device::cpu().to_string(), "CPU");
        assert_eq!(Device::cuda(0).to_string(), "CUDA:0");
        assert_eq!(Device::cuda(1).to_string(), "CUDA:1");
        assert_eq!(Device::metal().to_string(), "Metal:0");
        assert_eq!(Device::vulkan(2).to_string(), "Vulkan:2");
        assert_eq!(Device::rocm(0).to_string(), "ROCm:0");

        // Only CPU ordinal 0 drops the `:INDEX` suffix.
        let cpu1 = Device {
            device_type: DeviceType::Cpu,
            index: 1,
        };
        assert_eq!(cpu1.to_string(), "CPU:1");
    }

    #[test]
    fn test_device_manager_creation() {
        let manager = DeviceManager::new();
        assert!(!manager.available_devices().is_empty());
        assert!(manager.default_device().is_cpu());
    }

    #[test]
    fn test_device_manager_queries() {
        let manager = DeviceManager::new();

        // CPU should always be available
        assert!(manager.is_available(&Device::cpu()));
        assert_eq!(manager.count_devices(DeviceType::Cpu), 1);

        // Check default device
        assert_eq!(manager.default_device(), &Device::cpu());

        // Under cfg(test) the manager registers only the CPU.
        assert_eq!(manager.available_devices(), &[Device::cpu()][..]);
        assert_eq!(manager.count_devices(DeviceType::Cuda), 0);
        assert!(!manager.has_gpu());
    }

    #[test]
    fn test_device_manager_set_default() {
        let mut manager = DeviceManager::new();
        let cpu = Device::cpu();

        // Setting to an available device should succeed
        assert!(manager.set_default_device(cpu.clone()).is_ok());
        assert_eq!(manager.default_device(), &cpu);

        // Setting to an unavailable device should fail and report that device
        let cuda = Device::cuda(99);
        let err = manager.set_default_device(cuda.clone()).unwrap_err();
        assert!(matches!(&err, DeviceError::DeviceNotAvailable(d) if d == &cuda));

        // A failed call must leave the previous default untouched
        assert_eq!(manager.default_device(), &cpu);
    }

    #[test]
    fn test_device_manager_get_device() {
        let manager = DeviceManager::new();

        // Should find CPU
        assert_eq!(manager.get_device(DeviceType::Cpu, 0), Some(&Device::cpu()));

        // Should not find non-existent devices
        assert_eq!(manager.get_device(DeviceType::Cpu, 1), None);
        assert_eq!(manager.get_device(DeviceType::Cuda, 0), None);
    }

    #[test]
    fn test_device_error_display() {
        let err = DeviceError::DeviceNotAvailable(Device::cuda(0));
        assert!(err.to_string().contains("not available"));
        assert_eq!(err.to_string(), "Device not available: CUDA:0");

        let err = DeviceError::AllocationFailed("out of memory".to_string());
        assert!(err.to_string().contains("allocation failed"));

        let err = DeviceError::SyncFailed("timeout".to_string());
        assert_eq!(err.to_string(), "Device synchronization failed: timeout");

        let err = DeviceError::UnsupportedOperation {
            device: Device::metal(),
            operation: "fft".to_string(),
        };
        assert_eq!(err.to_string(), "Unsupported operation on device Metal:0: fft");
    }
}