Skip to main content

trustformers_core/
device.rs

1//! Device abstraction for hardware acceleration
2//!
3//! This module provides a simple Device enum for specifying where computations
4//! should be executed (CPU, CUDA GPU, Metal GPU, etc.).
5
6use serde::{Deserialize, Serialize};
7
8/// Device specification for tensor operations
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
10pub enum Device {
11    /// CPU execution
12    #[default]
13    CPU,
14    /// NVIDIA CUDA GPU (with device ID)
15    CUDA(usize),
16    /// Apple Metal GPU (with device ID)
17    Metal(usize),
18    /// AMD ROCm GPU (with device ID)
19    ROCm(usize),
20    /// WebGPU
21    WebGPU,
22}
23
24impl Device {
25    /// Returns true if this device is a GPU
26    pub fn is_gpu(&self) -> bool {
27        !matches!(self, Device::CPU)
28    }
29
30    /// Returns true if this device is Metal GPU
31    pub fn is_metal(&self) -> bool {
32        matches!(self, Device::Metal(_))
33    }
34
35    /// Returns true if this device is CUDA GPU
36    pub fn is_cuda(&self) -> bool {
37        matches!(self, Device::CUDA(_))
38    }
39
40    /// Returns true if this device is CPU
41    pub fn is_cpu(&self) -> bool {
42        matches!(self, Device::CPU)
43    }
44
45    /// Returns the device ID for GPU devices
46    pub fn device_id(&self) -> Option<usize> {
47        match self {
48            Device::CUDA(id) | Device::Metal(id) | Device::ROCm(id) => Some(*id),
49            _ => None,
50        }
51    }
52
53    /// Create a Metal device, or CPU if Metal is not available
54    #[cfg(all(target_os = "macos", feature = "metal"))]
55    pub fn metal_if_available(device_id: usize) -> Device {
56        // Try to initialize Metal backend to check availability
57        // This is more reliable than scirs2's platform detection
58        #[cfg(all(target_os = "macos", feature = "metal"))]
59        {
60            use metal::Device as MetalDevice;
61            if MetalDevice::system_default().is_some() {
62                Device::Metal(device_id)
63            } else {
64                Device::CPU
65            }
66        }
67
68        #[cfg(not(all(target_os = "macos", feature = "metal")))]
69        Device::CPU
70    }
71
72    #[cfg(not(all(target_os = "macos", feature = "metal")))]
73    pub fn metal_if_available(_device_id: usize) -> Device {
74        Device::CPU
75    }
76
77    /// Create a CUDA device, or CPU if CUDA is not available
78    #[cfg(feature = "cuda")]
79    pub fn cuda_if_available(device_id: usize) -> Device {
80        use scirs2_core::simd_ops::PlatformCapabilities;
81
82        let caps = PlatformCapabilities::detect();
83        if caps.cuda_available {
84            Device::CUDA(device_id)
85        } else {
86            Device::CPU
87        }
88    }
89
90    #[cfg(not(feature = "cuda"))]
91    pub fn cuda_if_available(_device_id: usize) -> Device {
92        Device::CPU
93    }
94
95    /// Get the best available device (prefers GPU over CPU)
96    pub fn best_available() -> Device {
97        #[cfg(all(target_os = "macos", feature = "metal"))]
98        {
99            use scirs2_core::simd_ops::PlatformCapabilities;
100            let caps = PlatformCapabilities::detect();
101            if caps.metal_available {
102                return Device::Metal(0);
103            }
104        }
105
106        #[cfg(feature = "cuda")]
107        {
108            use scirs2_core::simd_ops::PlatformCapabilities;
109            let caps = PlatformCapabilities::detect();
110            if caps.cuda_available {
111                return Device::CUDA(0);
112            }
113        }
114
115        Device::CPU
116    }
117}
118
119impl std::fmt::Display for Device {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        match self {
122            Device::CPU => write!(f, "CPU"),
123            Device::CUDA(id) => write!(f, "CUDA:{}", id),
124            Device::Metal(id) => write!(f, "Metal:{}", id),
125            Device::ROCm(id) => write!(f, "ROCm:{}", id),
126            Device::WebGPU => write!(f, "WebGPU"),
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_checks() {
        assert!(Device::CPU.is_cpu());
        assert!(!Device::CPU.is_gpu());
        assert!(!Device::CPU.is_metal());

        assert!(Device::Metal(0).is_metal());
        assert!(Device::Metal(0).is_gpu());
        assert!(!Device::Metal(0).is_cpu());

        assert!(Device::CUDA(0).is_cuda());
        assert!(Device::CUDA(0).is_gpu());

        // ROCm and WebGPU count as GPUs but are neither Metal nor CUDA.
        assert!(Device::ROCm(0).is_gpu());
        assert!(!Device::ROCm(0).is_cuda());
        assert!(!Device::ROCm(0).is_metal());
        assert!(Device::WebGPU.is_gpu());
        assert!(!Device::WebGPU.is_cpu());
    }

    #[test]
    fn test_device_id() {
        assert_eq!(Device::CPU.device_id(), None);
        assert_eq!(Device::Metal(0).device_id(), Some(0));
        assert_eq!(Device::CUDA(1).device_id(), Some(1));
        assert_eq!(Device::ROCm(2).device_id(), Some(2));
        assert_eq!(Device::WebGPU.device_id(), None);
    }

    #[test]
    fn test_device_display() {
        assert_eq!(Device::CPU.to_string(), "CPU");
        assert_eq!(Device::Metal(0).to_string(), "Metal:0");
        assert_eq!(Device::CUDA(1).to_string(), "CUDA:1");
        assert_eq!(Device::ROCm(2).to_string(), "ROCm:2");
        assert_eq!(Device::WebGPU.to_string(), "WebGPU");
    }

    #[test]
    fn test_default_is_cpu() {
        assert_eq!(Device::default(), Device::CPU);
    }

    #[test]
    fn test_availability_fallbacks() {
        // Results depend on enabled features and host hardware, so only check
        // that each constructor yields its own backend or falls back to CPU.
        assert!(matches!(
            Device::metal_if_available(1),
            Device::Metal(1) | Device::CPU
        ));
        assert!(matches!(
            Device::cuda_if_available(1),
            Device::CUDA(1) | Device::CPU
        ));
        assert!(matches!(
            Device::best_available(),
            Device::Metal(0) | Device::CUDA(0) | Device::CPU
        ));
    }
}