1use std::path::Path;
2use std::str::FromStr;
3
4use tracing::debug;
5
6use uv_pep440::Version;
7use uv_static::EnvVars;
8
9#[cfg(windows)]
10use serde::Deserialize;
11#[cfg(windows)]
12use wmi::{COMLibrary, WMIConnection};
13
14#[derive(Debug, thiserror::Error)]
15pub enum AcceleratorError {
16 #[error(transparent)]
17 Io(#[from] std::io::Error),
18 #[error(transparent)]
19 Version(#[from] uv_pep440::VersionParseError),
20 #[error(transparent)]
21 Utf8(#[from] std::string::FromUtf8Error),
22 #[error(transparent)]
23 ParseInt(#[from] std::num::ParseIntError),
24 #[error("Unknown AMD GPU architecture: {0}")]
25 UnknownAmdGpuArchitecture(String),
26}
27
28#[derive(Debug, Clone, Eq, PartialEq)]
29pub enum Accelerator {
30 Cuda { driver_version: Version },
34 Amd {
39 gpu_architecture: AmdGpuArchitecture,
40 },
41 Xpu,
45}
46
47impl std::fmt::Display for Accelerator {
48 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49 match self {
50 Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"),
51 Self::Amd { gpu_architecture } => write!(f, "AMD {gpu_architecture}"),
52 Self::Xpu => write!(f, "Intel GPU (XPU)"),
53 }
54 }
55}
56
57impl Accelerator {
58 pub fn detect() -> Result<Option<Self>, AcceleratorError> {
70 const PCI_BASE_CLASS_MASK: u32 = 0x00ff_0000;
72 const PCI_BASE_CLASS_DISPLAY: u32 = 0x0003_0000;
73 const PCI_VENDOR_ID_INTEL: u32 = 0x8086;
74
75 if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) {
77 let driver_version = Version::from_str(&driver_version)?;
78 debug!("Detected CUDA driver version from `UV_CUDA_DRIVER_VERSION`: {driver_version}");
79 return Ok(Some(Self::Cuda { driver_version }));
80 }
81
82 if let Ok(gpu_architecture) = std::env::var(EnvVars::UV_AMD_GPU_ARCHITECTURE) {
84 let gpu_architecture = AmdGpuArchitecture::from_str(&gpu_architecture)?;
85 debug!(
86 "Detected AMD GPU architecture from `UV_AMD_GPU_ARCHITECTURE`: {gpu_architecture}"
87 );
88 return Ok(Some(Self::Amd { gpu_architecture }));
89 }
90
91 match fs_err::read_to_string("/sys/module/nvidia/version") {
93 Ok(content) => {
94 return match parse_sys_module_nvidia_version(&content) {
95 Ok(driver_version) => {
96 debug!(
97 "Detected CUDA driver version from `/sys/module/nvidia/version`: {driver_version}"
98 );
99 Ok(Some(Self::Cuda { driver_version }))
100 }
101 Err(e) => Err(e),
102 };
103 }
104 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
105 Err(e) => return Err(e.into()),
106 }
107
108 match fs_err::read_to_string("/proc/driver/nvidia/version") {
110 Ok(content) => match parse_proc_driver_nvidia_version(&content) {
111 Ok(Some(driver_version)) => {
112 debug!(
113 "Detected CUDA driver version from `/proc/driver/nvidia/version`: {driver_version}"
114 );
115 return Ok(Some(Self::Cuda { driver_version }));
116 }
117 Ok(None) => {
118 debug!(
119 "Failed to parse CUDA driver version from `/proc/driver/nvidia/version`"
120 );
121 }
122 Err(e) => return Err(e),
123 },
124 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
125 Err(e) => return Err(e.into()),
126 }
127
128 if let Ok(output) = std::process::Command::new("nvidia-smi")
130 .arg("--query-gpu=driver_version")
131 .arg("--format=csv,noheader")
132 .output()
133 {
134 if output.status.success() {
135 let stdout = String::from_utf8(output.stdout)?;
136 if let Some(first_line) = stdout.lines().next() {
137 let driver_version = Version::from_str(first_line.trim())?;
138 debug!("Detected CUDA driver version from `nvidia-smi`: {driver_version}");
139 return Ok(Some(Self::Cuda { driver_version }));
140 }
141 }
142
143 debug!(
144 "Failed to query CUDA driver version with `nvidia-smi` with status `{}`: {}",
145 output.status,
146 String::from_utf8_lossy(&output.stderr)
147 );
148 }
149
150 if let Ok(output) = std::process::Command::new("rocm_agent_enumerator").output() {
154 if output.status.success() {
155 let stdout = String::from_utf8(output.stdout)?;
156 if let Some(gpu_architecture) = stdout
157 .lines()
158 .map(str::trim)
159 .filter_map(|line| AmdGpuArchitecture::from_str(line).ok())
160 .min()
161 {
162 debug!(
163 "Detected AMD GPU architecture from `rocm_agent_enumerator`: {gpu_architecture}"
164 );
165 return Ok(Some(Self::Amd { gpu_architecture }));
166 }
167 } else {
168 debug!(
169 "Failed to query AMD GPU architecture with `rocm_agent_enumerator` with status `{}`: {}",
170 output.status,
171 String::from_utf8_lossy(&output.stderr)
172 );
173 }
174 }
175
176 match fs_err::read_dir("/sys/bus/pci/devices") {
178 Ok(entries) => {
179 for entry in entries.flatten() {
180 match parse_pci_device_ids(&entry.path()) {
181 Ok((class, vendor)) => {
182 if (class & PCI_BASE_CLASS_MASK) == PCI_BASE_CLASS_DISPLAY
183 && vendor == PCI_VENDOR_ID_INTEL
184 {
185 debug!("Detected Intel GPU from PCI: vendor=0x{:04x}", vendor);
186 return Ok(Some(Self::Xpu));
187 }
188 }
189 Err(e) => {
190 debug!("Failed to parse PCI device IDs: {e}");
191 }
192 }
193 }
194 }
195 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
196 Err(e) => return Err(e.into()),
197 }
198
199 #[cfg(windows)]
201 {
202 #[derive(Deserialize, Debug)]
203 #[serde(rename = "Win32_VideoController")]
204 #[serde(rename_all = "PascalCase")]
205 struct VideoController {
206 #[serde(rename = "PNPDeviceID")]
207 pnp_device_id: Option<String>,
208 name: Option<String>,
209 }
210
211 match COMLibrary::new() {
212 Ok(com_library) => match WMIConnection::new(com_library) {
213 Ok(wmi_connection) => match wmi_connection.query::<VideoController>() {
214 Ok(gpu_controllers) => {
215 for gpu_controller in gpu_controllers {
216 if let Some(pnp_device_id) = &gpu_controller.pnp_device_id {
217 if pnp_device_id
218 .contains(&format!("VEN_{PCI_VENDOR_ID_INTEL:04X}"))
219 {
220 debug!(
221 "Detected Intel GPU from WMI: PNPDeviceID={}, Name={:?}",
222 pnp_device_id, gpu_controller.name
223 );
224 return Ok(Some(Self::Xpu));
225 }
226 }
227 }
228 }
229 Err(e) => {
230 debug!("Failed to query WMI for video controllers: {e}");
231 }
232 },
233 Err(e) => {
234 debug!("Failed to create WMI connection: {e}");
235 }
236 },
237 Err(e) => {
238 debug!("Failed to initialize COM library: {e}");
239 }
240 }
241 }
242
243 debug!("Failed to detect GPU driver version");
244
245 Ok(None)
246 }
247}
248
249fn parse_sys_module_nvidia_version(content: &str) -> Result<Version, AcceleratorError> {
251 let driver_version = Version::from_str(content.trim())?;
256 Ok(driver_version)
257}
258
259fn parse_proc_driver_nvidia_version(content: &str) -> Result<Option<Version>, AcceleratorError> {
261 let Some(version) = content.split(" ").nth(1) else {
267 return Ok(None);
268 };
269 let driver_version = Version::from_str(version.trim())?;
270 Ok(Some(driver_version))
271}
272
273fn parse_pci_device_ids(device_path: &Path) -> Result<(u32, u32), AcceleratorError> {
275 let class_content = fs_err::read_to_string(device_path.join("class"))?;
281 let pci_class = u32::from_str_radix(class_content.trim().trim_start_matches("0x"), 16)?;
282
283 let vendor_content = fs_err::read_to_string(device_path.join("vendor"))?;
284 let pci_vendor = u32::from_str_radix(vendor_content.trim().trim_start_matches("0x"), 16)?;
285
286 Ok((pci_class, pci_vendor))
287}
288
/// The AMD GPU architectures recognized by accelerator detection.
///
/// The derived ordering follows the declaration order below, so `Gfx900` is the smallest
/// value and `Gfx1201` is the largest.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum AmdGpuArchitecture {
    /// The `gfx900` architecture.
    Gfx900,
    /// The `gfx906` architecture.
    Gfx906,
    /// The `gfx908` architecture.
    Gfx908,
    /// The `gfx90a` architecture.
    Gfx90a,
    /// The `gfx942` architecture.
    Gfx942,
    /// The `gfx1030` architecture.
    Gfx1030,
    /// The `gfx1100` architecture.
    Gfx1100,
    /// The `gfx1101` architecture.
    Gfx1101,
    /// The `gfx1102` architecture.
    Gfx1102,
    /// The `gfx1200` architecture.
    Gfx1200,
    /// The `gfx1201` architecture.
    Gfx1201,
}
306
307impl FromStr for AmdGpuArchitecture {
308 type Err = AcceleratorError;
309
310 fn from_str(s: &str) -> Result<Self, Self::Err> {
311 match s {
312 "gfx900" => Ok(Self::Gfx900),
313 "gfx906" => Ok(Self::Gfx906),
314 "gfx908" => Ok(Self::Gfx908),
315 "gfx90a" => Ok(Self::Gfx90a),
316 "gfx942" => Ok(Self::Gfx942),
317 "gfx1030" => Ok(Self::Gfx1030),
318 "gfx1100" => Ok(Self::Gfx1100),
319 "gfx1101" => Ok(Self::Gfx1101),
320 "gfx1102" => Ok(Self::Gfx1102),
321 "gfx1200" => Ok(Self::Gfx1200),
322 "gfx1201" => Ok(Self::Gfx1201),
323 _ => Err(AcceleratorError::UnknownAmdGpuArchitecture(s.to_string())),
324 }
325 }
326}
327
328impl std::fmt::Display for AmdGpuArchitecture {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 match self {
331 Self::Gfx900 => write!(f, "gfx900"),
332 Self::Gfx906 => write!(f, "gfx906"),
333 Self::Gfx908 => write!(f, "gfx908"),
334 Self::Gfx90a => write!(f, "gfx90a"),
335 Self::Gfx942 => write!(f, "gfx942"),
336 Self::Gfx1030 => write!(f, "gfx1030"),
337 Self::Gfx1100 => write!(f, "gfx1100"),
338 Self::Gfx1101 => write!(f, "gfx1101"),
339 Self::Gfx1102 => write!(f, "gfx1102"),
340 Self::Gfx1200 => write!(f, "gfx1200"),
341 Self::Gfx1201 => write!(f, "gfx1201"),
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the version extracted from two sample `/proc/driver/nvidia/version` files.
    #[test]
    fn proc_driver_nvidia_version() {
        // Each case pairs the file contents with the driver version expected from it.
        let cases = [
            (
                "NVRM version: NVIDIA UNIX Open Kernel Module for x86_64 550.144.03 Release Build (dvs-builder@U16-I3-D08-1-2) Mon Dec 30 17:26:13 UTC 2024\nGCC version: gcc version 12.3.0 (Ubuntu 12.3.0-1ubuntu1~22.04)",
                "550.144.03",
            ),
            (
                "NVRM version: NVIDIA UNIX x86_64 Kernel Module 375.74 Wed Jun 14 01:39:39 PDT 2017\nGCC version: gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.4)",
                "375.74",
            ),
        ];

        for (content, expected) in cases {
            let result = parse_proc_driver_nvidia_version(content).unwrap();
            assert_eq!(result, Some(Version::from_str(expected).unwrap()));
        }
    }

    /// Mirrors how `nvidia-smi` output is consumed: only the first line is parsed, so
    /// single-GPU and multi-GPU output yield the same version.
    #[test]
    fn nvidia_smi_multi_gpu() {
        for stdout in ["572.60\n", "572.60\n572.60\n"] {
            if let Some(first_line) = stdout.lines().next() {
                let version = Version::from_str(first_line.trim()).unwrap();
                assert_eq!(version, Version::from_str("572.60").unwrap());
            }
        }
    }
}