trustformers_core/
device.rs1use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
10pub enum Device {
11 #[default]
13 CPU,
14 CUDA(usize),
16 Metal(usize),
18 ROCm(usize),
20 WebGPU,
22}
23
24impl Device {
25 pub fn is_gpu(&self) -> bool {
27 !matches!(self, Device::CPU)
28 }
29
30 pub fn is_metal(&self) -> bool {
32 matches!(self, Device::Metal(_))
33 }
34
35 pub fn is_cuda(&self) -> bool {
37 matches!(self, Device::CUDA(_))
38 }
39
40 pub fn is_cpu(&self) -> bool {
42 matches!(self, Device::CPU)
43 }
44
45 pub fn device_id(&self) -> Option<usize> {
47 match self {
48 Device::CUDA(id) | Device::Metal(id) | Device::ROCm(id) => Some(*id),
49 _ => None,
50 }
51 }
52
53 #[cfg(all(target_os = "macos", feature = "metal"))]
55 pub fn metal_if_available(device_id: usize) -> Device {
56 #[cfg(all(target_os = "macos", feature = "metal"))]
59 {
60 use metal::Device as MetalDevice;
61 if MetalDevice::system_default().is_some() {
62 Device::Metal(device_id)
63 } else {
64 Device::CPU
65 }
66 }
67
68 #[cfg(not(all(target_os = "macos", feature = "metal")))]
69 Device::CPU
70 }
71
72 #[cfg(not(all(target_os = "macos", feature = "metal")))]
73 pub fn metal_if_available(_device_id: usize) -> Device {
74 Device::CPU
75 }
76
77 #[cfg(feature = "cuda")]
79 pub fn cuda_if_available(device_id: usize) -> Device {
80 use scirs2_core::simd_ops::PlatformCapabilities;
81
82 let caps = PlatformCapabilities::detect();
83 if caps.cuda_available {
84 Device::CUDA(device_id)
85 } else {
86 Device::CPU
87 }
88 }
89
90 #[cfg(not(feature = "cuda"))]
91 pub fn cuda_if_available(_device_id: usize) -> Device {
92 Device::CPU
93 }
94
95 pub fn best_available() -> Device {
97 #[cfg(all(target_os = "macos", feature = "metal"))]
98 {
99 use scirs2_core::simd_ops::PlatformCapabilities;
100 let caps = PlatformCapabilities::detect();
101 if caps.metal_available {
102 return Device::Metal(0);
103 }
104 }
105
106 #[cfg(feature = "cuda")]
107 {
108 use scirs2_core::simd_ops::PlatformCapabilities;
109 let caps = PlatformCapabilities::detect();
110 if caps.cuda_available {
111 return Device::CUDA(0);
112 }
113 }
114
115 Device::CPU
116 }
117}
118
119impl std::fmt::Display for Device {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 match self {
122 Device::CPU => write!(f, "CPU"),
123 Device::CUDA(id) => write!(f, "CUDA:{}", id),
124 Device::Metal(id) => write!(f, "Metal:{}", id),
125 Device::ROCm(id) => write!(f, "ROCm:{}", id),
126 Device::WebGPU => write!(f, "WebGPU"),
127 }
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_checks() {
        assert!(Device::CPU.is_cpu());
        assert!(!Device::CPU.is_gpu());
        assert!(!Device::CPU.is_metal());

        assert!(Device::Metal(0).is_metal());
        assert!(Device::Metal(0).is_gpu());
        assert!(!Device::Metal(0).is_cpu());

        assert!(Device::CUDA(0).is_cuda());
        assert!(Device::CUDA(0).is_gpu());
    }

    /// ROCm and WebGPU count as GPUs but are neither CUDA nor Metal.
    #[test]
    fn test_rocm_and_webgpu_checks() {
        for dev in [Device::ROCm(0), Device::WebGPU] {
            assert!(dev.is_gpu(), "{dev} should count as a GPU");
            assert!(!dev.is_cpu());
            assert!(!dev.is_cuda());
            assert!(!dev.is_metal());
        }
    }

    #[test]
    fn test_default_is_cpu() {
        assert_eq!(Device::default(), Device::CPU);
    }

    #[test]
    fn test_device_id() {
        assert_eq!(Device::CPU.device_id(), None);
        assert_eq!(Device::Metal(0).device_id(), Some(0));
        assert_eq!(Device::CUDA(1).device_id(), Some(1));
        assert_eq!(Device::ROCm(2).device_id(), Some(2));
        assert_eq!(Device::WebGPU.device_id(), None);
    }

    #[test]
    fn test_device_display() {
        assert_eq!(Device::CPU.to_string(), "CPU");
        assert_eq!(Device::Metal(0).to_string(), "Metal:0");
        assert_eq!(Device::CUDA(1).to_string(), "CUDA:1");
        assert_eq!(Device::ROCm(2).to_string(), "ROCm:2");
        assert_eq!(Device::WebGPU.to_string(), "WebGPU");
    }

    /// Without the `cuda` feature the CUDA probe must always fall back to CPU.
    #[test]
    #[cfg(not(feature = "cuda"))]
    fn test_cuda_if_available_without_feature() {
        assert_eq!(Device::cuda_if_available(3), Device::CPU);
    }

    /// Without Metal support compiled in, the Metal probe must fall back to CPU.
    #[test]
    #[cfg(not(all(target_os = "macos", feature = "metal")))]
    fn test_metal_if_available_without_feature() {
        assert_eq!(Device::metal_if_available(3), Device::CPU);
    }

    /// Whatever hardware is present, `best_available` yields either the CPU
    /// or an accelerator at index 0.
    #[test]
    fn test_best_available_is_cpu_or_index_zero() {
        let best = Device::best_available();
        assert!(
            best.is_cpu() || best.device_id() == Some(0),
            "unexpected device {best}"
        );
    }
}