uv_torch/
accelerator.rs

1use std::path::Path;
2use std::str::FromStr;
3
4use tracing::debug;
5
6use uv_pep440::Version;
7use uv_static::EnvVars;
8
9#[cfg(windows)]
10use serde::Deserialize;
11#[cfg(windows)]
12use wmi::{COMLibrary, WMIConnection};
13
14#[derive(Debug, thiserror::Error)]
15pub enum AcceleratorError {
16    #[error(transparent)]
17    Io(#[from] std::io::Error),
18    #[error(transparent)]
19    Version(#[from] uv_pep440::VersionParseError),
20    #[error(transparent)]
21    Utf8(#[from] std::string::FromUtf8Error),
22    #[error(transparent)]
23    ParseInt(#[from] std::num::ParseIntError),
24    #[error("Unknown AMD GPU architecture: {0}")]
25    UnknownAmdGpuArchitecture(String),
26}
27
28#[derive(Debug, Clone, Eq, PartialEq)]
29pub enum Accelerator {
30    /// The CUDA driver version (e.g., `550.144.03`).
31    ///
32    /// This is in contrast to the CUDA toolkit version (e.g., `12.8.0`).
33    Cuda { driver_version: Version },
34    /// The AMD GPU architecture (e.g., `gfx906`).
35    ///
36    /// This is in contrast to the user-space ROCm version (e.g., `6.4.0-47`) or the kernel-mode
37    /// driver version (e.g., `6.12.12`).
38    Amd {
39        gpu_architecture: AmdGpuArchitecture,
40    },
41    /// The Intel GPU (XPU).
42    ///
43    /// Currently, Intel GPUs do not depend on a driver or toolkit version at this level.
44    Xpu,
45}
46
47impl std::fmt::Display for Accelerator {
48    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49        match self {
50            Self::Cuda { driver_version } => write!(f, "CUDA {driver_version}"),
51            Self::Amd { gpu_architecture } => write!(f, "AMD {gpu_architecture}"),
52            Self::Xpu => write!(f, "Intel GPU (XPU)"),
53        }
54    }
55}
56
57impl Accelerator {
58    /// Detect the GPU driver and/or architecture version from the system.
59    ///
60    /// Query, in order:
61    /// 1. The `UV_CUDA_DRIVER_VERSION` environment variable.
62    /// 2. The `UV_AMD_GPU_ARCHITECTURE` environment variable.
63    /// 3. `/sys/module/nvidia/version`, which contains the driver version (e.g., `550.144.03`).
64    /// 4. `/proc/driver/nvidia/version`, which contains the driver version among other information.
65    /// 5. `nvidia-smi --query-gpu=driver_version --format=csv,noheader`.
66    /// 6. `rocm_agent_enumerator`, which lists the AMD GPU architectures.
67    /// 7. `/sys/bus/pci/devices`, filtering for the Intel GPU via PCI.
68    /// 8. Windows Managmeent Instrumentation (WMI), filtering for the Intel GPU via PCI.
69    pub fn detect() -> Result<Option<Self>, AcceleratorError> {
70        // Constants used for PCI device detection.
71        const PCI_BASE_CLASS_MASK: u32 = 0x00ff_0000;
72        const PCI_BASE_CLASS_DISPLAY: u32 = 0x0003_0000;
73        const PCI_VENDOR_ID_INTEL: u32 = 0x8086;
74
75        // Read from `UV_CUDA_DRIVER_VERSION`.
76        if let Ok(driver_version) = std::env::var(EnvVars::UV_CUDA_DRIVER_VERSION) {
77            let driver_version = Version::from_str(&driver_version)?;
78            debug!("Detected CUDA driver version from `UV_CUDA_DRIVER_VERSION`: {driver_version}");
79            return Ok(Some(Self::Cuda { driver_version }));
80        }
81
82        // Read from `UV_AMD_GPU_ARCHITECTURE`.
83        if let Ok(gpu_architecture) = std::env::var(EnvVars::UV_AMD_GPU_ARCHITECTURE) {
84            let gpu_architecture = AmdGpuArchitecture::from_str(&gpu_architecture)?;
85            debug!(
86                "Detected AMD GPU architecture from `UV_AMD_GPU_ARCHITECTURE`: {gpu_architecture}"
87            );
88            return Ok(Some(Self::Amd { gpu_architecture }));
89        }
90
91        // Read from `/sys/module/nvidia/version`.
92        match fs_err::read_to_string("/sys/module/nvidia/version") {
93            Ok(content) => {
94                return match parse_sys_module_nvidia_version(&content) {
95                    Ok(driver_version) => {
96                        debug!(
97                            "Detected CUDA driver version from `/sys/module/nvidia/version`: {driver_version}"
98                        );
99                        Ok(Some(Self::Cuda { driver_version }))
100                    }
101                    Err(e) => Err(e),
102                };
103            }
104            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
105            Err(e) => return Err(e.into()),
106        }
107
108        // Read from `/proc/driver/nvidia/version`
109        match fs_err::read_to_string("/proc/driver/nvidia/version") {
110            Ok(content) => match parse_proc_driver_nvidia_version(&content) {
111                Ok(Some(driver_version)) => {
112                    debug!(
113                        "Detected CUDA driver version from `/proc/driver/nvidia/version`: {driver_version}"
114                    );
115                    return Ok(Some(Self::Cuda { driver_version }));
116                }
117                Ok(None) => {
118                    debug!(
119                        "Failed to parse CUDA driver version from `/proc/driver/nvidia/version`"
120                    );
121                }
122                Err(e) => return Err(e),
123            },
124            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
125            Err(e) => return Err(e.into()),
126        }
127
128        // Query `nvidia-smi`.
129        if let Ok(output) = std::process::Command::new("nvidia-smi")
130            .arg("--query-gpu=driver_version")
131            .arg("--format=csv,noheader")
132            .output()
133        {
134            if output.status.success() {
135                let stdout = String::from_utf8(output.stdout)?;
136                if let Some(first_line) = stdout.lines().next() {
137                    let driver_version = Version::from_str(first_line.trim())?;
138                    debug!("Detected CUDA driver version from `nvidia-smi`: {driver_version}");
139                    return Ok(Some(Self::Cuda { driver_version }));
140                }
141            }
142
143            debug!(
144                "Failed to query CUDA driver version with `nvidia-smi` with status `{}`: {}",
145                output.status,
146                String::from_utf8_lossy(&output.stderr)
147            );
148        }
149
150        // Query `rocm_agent_enumerator` to detect the AMD GPU architecture.
151        //
152        // See: https://rocm.docs.amd.com/projects/rocminfo/en/latest/how-to/use-rocm-agent-enumerator.html
153        if let Ok(output) = std::process::Command::new("rocm_agent_enumerator").output() {
154            if output.status.success() {
155                let stdout = String::from_utf8(output.stdout)?;
156                if let Some(gpu_architecture) = stdout
157                    .lines()
158                    .map(str::trim)
159                    .filter_map(|line| AmdGpuArchitecture::from_str(line).ok())
160                    .min()
161                {
162                    debug!(
163                        "Detected AMD GPU architecture from `rocm_agent_enumerator`: {gpu_architecture}"
164                    );
165                    return Ok(Some(Self::Amd { gpu_architecture }));
166                }
167            } else {
168                debug!(
169                    "Failed to query AMD GPU architecture with `rocm_agent_enumerator` with status `{}`: {}",
170                    output.status,
171                    String::from_utf8_lossy(&output.stderr)
172                );
173            }
174        }
175
176        // Read from `/sys/bus/pci/devices` to filter for Intel GPU via PCI.
177        match fs_err::read_dir("/sys/bus/pci/devices") {
178            Ok(entries) => {
179                for entry in entries.flatten() {
180                    match parse_pci_device_ids(&entry.path()) {
181                        Ok((class, vendor)) => {
182                            if (class & PCI_BASE_CLASS_MASK) == PCI_BASE_CLASS_DISPLAY
183                                && vendor == PCI_VENDOR_ID_INTEL
184                            {
185                                debug!("Detected Intel GPU from PCI: vendor=0x{:04x}", vendor);
186                                return Ok(Some(Self::Xpu));
187                            }
188                        }
189                        Err(e) => {
190                            debug!("Failed to parse PCI device IDs: {e}");
191                        }
192                    }
193                }
194            }
195            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
196            Err(e) => return Err(e.into()),
197        }
198
199        // Detect Intel GPU via WMI on Windows
200        #[cfg(windows)]
201        {
202            #[derive(Deserialize, Debug)]
203            #[serde(rename = "Win32_VideoController")]
204            #[serde(rename_all = "PascalCase")]
205            struct VideoController {
206                #[serde(rename = "PNPDeviceID")]
207                pnp_device_id: Option<String>,
208                name: Option<String>,
209            }
210
211            match COMLibrary::new() {
212                Ok(com_library) => match WMIConnection::new(com_library) {
213                    Ok(wmi_connection) => match wmi_connection.query::<VideoController>() {
214                        Ok(gpu_controllers) => {
215                            for gpu_controller in gpu_controllers {
216                                if let Some(pnp_device_id) = &gpu_controller.pnp_device_id {
217                                    if pnp_device_id
218                                        .contains(&format!("VEN_{PCI_VENDOR_ID_INTEL:04X}"))
219                                    {
220                                        debug!(
221                                            "Detected Intel GPU from WMI: PNPDeviceID={}, Name={:?}",
222                                            pnp_device_id, gpu_controller.name
223                                        );
224                                        return Ok(Some(Self::Xpu));
225                                    }
226                                }
227                            }
228                        }
229                        Err(e) => {
230                            debug!("Failed to query WMI for video controllers: {e}");
231                        }
232                    },
233                    Err(e) => {
234                        debug!("Failed to create WMI connection: {e}");
235                    }
236                },
237                Err(e) => {
238                    debug!("Failed to initialize COM library: {e}");
239                }
240            }
241        }
242
243        debug!("Failed to detect GPU driver version");
244
245        Ok(None)
246    }
247}
248
249/// Parse the CUDA driver version from the content of `/sys/module/nvidia/version`.
250fn parse_sys_module_nvidia_version(content: &str) -> Result<Version, AcceleratorError> {
251    // Parse, e.g.:
252    // ```text
253    // 550.144.03
254    // ```
255    let driver_version = Version::from_str(content.trim())?;
256    Ok(driver_version)
257}
258
259/// Parse the CUDA driver version from the content of `/proc/driver/nvidia/version`.
260fn parse_proc_driver_nvidia_version(content: &str) -> Result<Option<Version>, AcceleratorError> {
261    // Parse, e.g.:
262    // ```text
263    // NVRM version: NVIDIA UNIX Open Kernel Module for x86_64  550.144.03  Release Build  (dvs-builder@U16-I3-D08-1-2)  Mon Dec 30 17:26:13 UTC 2024
264    // GCC version:  gcc version 12.3.0 (Ubuntu 12.3.0-1ubuntu1~22.04)
265    // ```
266    let Some(version) = content.split("  ").nth(1) else {
267        return Ok(None);
268    };
269    let driver_version = Version::from_str(version.trim())?;
270    Ok(Some(driver_version))
271}
272
273/// Reads and parses the PCI class and vendor ID from a given device path under `/sys/bus/pci/devices`.
274fn parse_pci_device_ids(device_path: &Path) -> Result<(u32, u32), AcceleratorError> {
275    // Parse, e.g.:
276    // ```text
277    // - `class`: a hexadecimal string such as `0x030000`
278    // - `vendor`: a hexadecimal string such as `0x8086`
279    // ```
280    let class_content = fs_err::read_to_string(device_path.join("class"))?;
281    let pci_class = u32::from_str_radix(class_content.trim().trim_start_matches("0x"), 16)?;
282
283    let vendor_content = fs_err::read_to_string(device_path.join("vendor"))?;
284    let pci_vendor = u32::from_str_radix(vendor_content.trim().trim_start_matches("0x"), 16)?;
285
286    Ok((pci_class, pci_vendor))
287}
288
/// A GPU architecture for AMD GPUs, named after its `gfx` target (e.g., `gfx906`).
///
/// The order of the variants is significant: `Ord` is derived, so architectures compare in
/// declaration order, and [`Accelerator::detect`] selects the minimum of the architectures
/// reported by `rocm_agent_enumerator`.
///
/// See: <https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html>
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum AmdGpuArchitecture {
    Gfx900,
    Gfx906,
    Gfx908,
    Gfx90a,
    Gfx942,
    Gfx1030,
    Gfx1100,
    Gfx1101,
    Gfx1102,
    Gfx1200,
    Gfx1201,
}
306
307impl FromStr for AmdGpuArchitecture {
308    type Err = AcceleratorError;
309
310    fn from_str(s: &str) -> Result<Self, Self::Err> {
311        match s {
312            "gfx900" => Ok(Self::Gfx900),
313            "gfx906" => Ok(Self::Gfx906),
314            "gfx908" => Ok(Self::Gfx908),
315            "gfx90a" => Ok(Self::Gfx90a),
316            "gfx942" => Ok(Self::Gfx942),
317            "gfx1030" => Ok(Self::Gfx1030),
318            "gfx1100" => Ok(Self::Gfx1100),
319            "gfx1101" => Ok(Self::Gfx1101),
320            "gfx1102" => Ok(Self::Gfx1102),
321            "gfx1200" => Ok(Self::Gfx1200),
322            "gfx1201" => Ok(Self::Gfx1201),
323            _ => Err(AcceleratorError::UnknownAmdGpuArchitecture(s.to_string())),
324        }
325    }
326}
327
328impl std::fmt::Display for AmdGpuArchitecture {
329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330        match self {
331            Self::Gfx900 => write!(f, "gfx900"),
332            Self::Gfx906 => write!(f, "gfx906"),
333            Self::Gfx908 => write!(f, "gfx908"),
334            Self::Gfx90a => write!(f, "gfx90a"),
335            Self::Gfx942 => write!(f, "gfx942"),
336            Self::Gfx1030 => write!(f, "gfx1030"),
337            Self::Gfx1100 => write!(f, "gfx1100"),
338            Self::Gfx1101 => write!(f, "gfx1101"),
339            Self::Gfx1102 => write!(f, "gfx1102"),
340            Self::Gfx1200 => write!(f, "gfx1200"),
341            Self::Gfx1201 => write!(f, "gfx1201"),
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// The driver version is extracted from the `NVRM version:` line, for both the open and the
    /// proprietary kernel module formats.
    #[test]
    fn proc_driver_nvidia_version() {
        let content = "NVRM version: NVIDIA UNIX Open Kernel Module for x86_64  550.144.03  Release Build  (dvs-builder@U16-I3-D08-1-2)  Mon Dec 30 17:26:13 UTC 2024\nGCC version:  gcc version 12.3.0 (Ubuntu 12.3.0-1ubuntu1~22.04)";
        let result = parse_proc_driver_nvidia_version(content).unwrap();
        assert_eq!(result, Some(Version::from_str("550.144.03").unwrap()));

        let content = "NVRM version: NVIDIA UNIX x86_64 Kernel Module  375.74  Wed Jun 14 01:39:39 PDT 2017\nGCC version:  gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.4)";
        let result = parse_proc_driver_nvidia_version(content).unwrap();
        assert_eq!(result, Some(Version::from_str("375.74").unwrap()));
    }

    /// The content of `/sys/module/nvidia/version` is the bare version followed by a newline.
    #[test]
    fn sys_module_nvidia_version() {
        let result = parse_sys_module_nvidia_version("550.144.03\n").unwrap();
        assert_eq!(result, Version::from_str("550.144.03").unwrap());

        assert!(parse_sys_module_nvidia_version("nvidia").is_err());
    }

    /// `nvidia-smi` prints one line per GPU; the version is taken from the first line.
    #[test]
    fn nvidia_smi_multi_gpu() {
        for stdout in ["572.60\n", "572.60\n572.60\n"] {
            // Fail loudly, rather than silently skipping the assertion, if there is no line.
            let first_line = stdout
                .lines()
                .next()
                .expect("`nvidia-smi` output should contain at least one line");
            let version = Version::from_str(first_line.trim()).unwrap();
            assert_eq!(version, Version::from_str("572.60").unwrap());
        }
    }

    /// Every known architecture name parses, and displays as the same name; unknown names are
    /// rejected with the offending input.
    #[test]
    fn amd_gpu_architecture_round_trip() {
        for name in [
            "gfx900", "gfx906", "gfx908", "gfx90a", "gfx942", "gfx1030", "gfx1100", "gfx1101",
            "gfx1102", "gfx1200", "gfx1201",
        ] {
            let architecture = AmdGpuArchitecture::from_str(name).unwrap();
            assert_eq!(architecture.to_string(), name);
        }

        assert!(matches!(
            AmdGpuArchitecture::from_str("gfx000"),
            Err(AcceleratorError::UnknownAmdGpuArchitecture(name)) if name == "gfx000"
        ));
    }
}