Skip to main content

zlayer_agent/
gpu_detector.rs

1//! GPU inventory detection
2//!
3//! Platform-specific GPU detection:
4//! - **Linux**: Scans `/sys/bus/pci/devices` for display controllers (VGA and 3D controllers).
5//!   Identifies vendor (NVIDIA, AMD, Intel) by PCI vendor ID, reads VRAM from PCI BAR regions,
6//!   and optionally uses `nvidia-smi` for NVIDIA-specific model and memory information.
7//! - **macOS**: Uses `system_profiler SPDisplaysDataType -json` to detect Apple Silicon GPUs
8//!   and unified memory via `sysctl -n hw.memsize`.
9//! - **Windows**: Layers NVML (via `nvml-wrapper`, loads `nvml.dll` at runtime) on top of
10//!   WMI `Win32_VideoController` enumeration (via the `wmi` crate). AMD VRAM is corrected by
11//!   reading `HardwareInformation.qwMemorySize` from the display-class registry subtree
12//!   (`HKLM\SYSTEM\CurrentControlSet\Control\Class\{4d36e968-...}\<xxxx>`) because WMI's
13//!   `AdapterRAM` is a `u32` capped at 4 GiB. NVML data is preferred when both surface the
14//!   same card.
15//! - **Other**: Returns an empty GPU list.
16//!
//! Beyond `serde`, Linux needs no external crates and macOS only `serde_json` (to parse
//! `system_profiler` output) -- both rely on sysfs / `system_profiler` scanning plus
//! optional subprocess calls. Windows pulls in `nvml-wrapper`, `wmi`, and `windows-registry`
//! (all gated to `target_os = "windows"` in the crate manifest).
20
21use serde::{Deserialize, Serialize};
22
/// Detected GPU information
///
/// One entry per GPU reported by the platform-specific `detect_gpus()`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GpuInfo {
    /// PCI bus ID (e.g., "0000:01:00.0" on Linux, "apple:0" on macOS). On Windows,
    /// NVML-detected cards use NVML's bus ID (or "nvml:<index>" when NVML could not
    /// report one) and WMI-only cards use a synthetic "0000:<vendor>:<device>.<slot>" form.
    pub pci_bus_id: String,
    /// Vendor: "nvidia", "amd", "intel", "apple", or "unknown"
    pub vendor: String,
    /// Model name (e.g., "Apple M2 Pro" or "NVIDIA A100-SXM4-80GB"); a generic
    /// placeholder such as "AMD GPU" when no better name is available
    pub model: String,
    /// VRAM in MB (0 if unknown; on Apple Silicon, this is unified memory)
    pub memory_mb: u64,
    /// Device path (e.g., "/dev/nvidia0", "/dev/dri/card0", "<iokit://AppleGPU/0>";
    /// on Windows "nvml://0" for NVML-detected cards, otherwise a `\\.\DISPLAY#<PnP id>` string)
    pub device_path: String,
    /// Render node path if applicable (e.g., "/dev/dri/renderD128"); None on macOS,
    /// on Windows, and for NVIDIA devices on Linux
    pub render_path: Option<String>,
}
39
40// =============================================================================
41// Linux GPU detection
42// =============================================================================
43
44/// Scan the system for GPU devices via sysfs PCI enumeration (Linux only)
45///
46/// Iterates over `/sys/bus/pci/devices` looking for PCI class codes that
47/// indicate display controllers:
48/// - `0x0300xx` -- VGA compatible controller
49/// - `0x0302xx` -- 3D controller (e.g., NVIDIA Tesla/datacenter GPUs)
50///
51/// For each GPU found, determines vendor, model name, VRAM, and device paths.
52#[cfg(target_os = "linux")]
53#[must_use]
54pub fn detect_gpus() -> Vec<GpuInfo> {
55    use std::path::Path;
56
57    let mut gpus = Vec::new();
58
59    let pci_dir = Path::new("/sys/bus/pci/devices");
60    if !pci_dir.exists() {
61        return gpus;
62    }
63
64    let Ok(entries) = std::fs::read_dir(pci_dir) else {
65        return gpus;
66    };
67
68    // Optionally pre-fetch nvidia-smi data once for all NVIDIA GPUs
69    let nvidia_data = NvidiaSmiData::fetch();
70
71    for entry in entries.flatten() {
72        let device_dir = entry.path();
73
74        // Read PCI device class
75        let class_path = device_dir.join("class");
76        let class = match std::fs::read_to_string(&class_path) {
77            Ok(c) => c.trim().to_string(),
78            Err(_) => continue,
79        };
80
81        // Filter to display controllers only
82        if !class.starts_with("0x0302") && !class.starts_with("0x0300") {
83            continue;
84        }
85
86        // Read PCI vendor ID
87        let vendor_path = device_dir.join("vendor");
88        let vendor_id = std::fs::read_to_string(&vendor_path)
89            .unwrap_or_default()
90            .trim()
91            .to_string();
92
93        let vendor = match vendor_id.as_str() {
94            "0x10de" => "nvidia",
95            "0x1002" => "amd",
96            "0x8086" => "intel",
97            _ => "unknown",
98        }
99        .to_string();
100
101        let pci_bus_id = entry.file_name().to_string_lossy().to_string();
102
103        // Count how many GPUs of this vendor we've already seen (for device path indexing)
104        let vendor_index = gpus
105            .iter()
106            .filter(|g: &&GpuInfo| g.vendor == vendor)
107            .count();
108
109        let model = read_gpu_model(&device_dir, &vendor, &nvidia_data, vendor_index);
110        let memory_mb = read_gpu_memory(&device_dir, &vendor, &nvidia_data, vendor_index);
111        let (device_path, render_path) = find_device_paths(&pci_bus_id, &vendor, vendor_index);
112
113        gpus.push(GpuInfo {
114            pci_bus_id,
115            vendor,
116            model,
117            memory_mb,
118            device_path,
119            render_path,
120        });
121    }
122
123    gpus
124}
125
126// =============================================================================
127// macOS GPU detection
128// =============================================================================
129
/// Detect GPUs on macOS via `system_profiler` (macOS only)
///
/// Enumerates display adapters with `system_profiler SPDisplaysDataType -json`.
/// Apple Silicon shares physical RAM between CPU and GPU, so the unified memory
/// pool (from `sysctl -n hw.memsize`) is reported as the GPU's memory there.
/// Returns an empty list if `system_profiler` cannot be run or parsed.
#[cfg(target_os = "macos")]
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    self::detect_apple_gpus()
}
141
/// Internal macOS GPU detection implementation
///
/// Parses `system_profiler SPDisplaysDataType -json`. Every element of the
/// `SPDisplaysDataType` array becomes one [`GpuInfo`] with a synthetic
/// `apple:<n>` bus id and `iokit://AppleGPU/<n>` device path, where `n` is the
/// element's position in the array. Returns an empty list if `system_profiler`
/// fails or its output is not the expected JSON.
///
/// Fields are looked up under the `sppci_*` keys first and the `spdisplays_*`
/// spellings second (NOTE(review): `system_profiler` JSON has used both prefixes;
/// confirm against current macOS output).
///
/// Memory: a VRAM figure reported by `system_profiler` wins. Otherwise GPUs
/// classified as Apple get the unified memory size from `sysctl hw.memsize`, and
/// other vendors report 0 (unknown) instead of borrowing all of system RAM.
#[cfg(target_os = "macos")]
fn detect_apple_gpus() -> Vec<GpuInfo> {
    /// First of `keys` whose value in `display` is a JSON string.
    fn first_str<'a>(display: &'a serde_json::Value, keys: &[&str]) -> Option<&'a str> {
        keys.iter()
            .find_map(|key| display.get(*key).and_then(serde_json::Value::as_str))
    }

    let output = match std::process::Command::new("system_profiler")
        .args(["SPDisplaysDataType", "-json"])
        .output()
    {
        Ok(out) if out.status.success() => out,
        _ => return Vec::new(),
    };

    let json_str = String::from_utf8_lossy(&output.stdout);
    let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_str) else {
        return Vec::new();
    };

    // system_profiler returns { "SPDisplaysDataType": [ { ... }, ... ] }
    let Some(displays) = parsed.get("SPDisplaysDataType").and_then(|v| v.as_array()) else {
        return Vec::new();
    };

    let unified_memory_mb = detect_unified_memory_mb();

    displays
        .iter()
        .enumerate()
        .map(|(idx, display)| {
            let base_model = first_str(display, &["sppci_model", "_name"]).unwrap_or("Apple GPU");
            let chip_type = first_str(display, &["sppci_chiptype"]).unwrap_or("");

            // Use chip type in model name if the model doesn't already include it
            let model = if !chip_type.is_empty() && !base_model.contains(chip_type) {
                format!("{base_model} ({chip_type})")
            } else {
                base_model.to_string()
            };

            let vendor_str =
                first_str(display, &["sppci_vendor", "spdisplays_vendor"]).unwrap_or("");
            let vendor = classify_mac_gpu_vendor(vendor_str, chip_type, &model);

            // Only Apple GPUs share the system memory pool; a discrete card whose
            // VRAM can't be read is reported as unknown (0).
            let memory_mb = first_str(
                display,
                &["sppci_vram", "spdisplays_vram", "spdisplays_vram_shared"],
            )
            .and_then(parse_vram_mb)
            .unwrap_or(if vendor == "apple" { unified_memory_mb } else { 0 });

            GpuInfo {
                pci_bus_id: format!("apple:{idx}"),
                vendor: vendor.to_string(),
                model,
                memory_mb,
                device_path: format!("iokit://AppleGPU/{idx}"),
                render_path: None,
            }
        })
        .collect()
}

/// Parse a `system_profiler` VRAM string such as `"16 GB"` or `"8192 MB"` into MB.
///
/// Returns `None` for anything else (missing unit, non-integer amount, other units).
// Also compiled in test builds so the parsing rules are covered on every host.
#[cfg(any(target_os = "macos", test))]
#[cfg_attr(not(target_os = "macos"), allow(dead_code))]
fn parse_vram_mb(raw: &str) -> Option<u64> {
    let mut parts = raw.split_whitespace();
    let amount: u64 = parts.next()?.parse().ok()?;
    match parts.next()?.to_ascii_uppercase().as_str() {
        "GB" => amount.checked_mul(1024),
        "MB" => Some(amount),
        _ => None,
    }
}

/// Classify a macOS GPU as "apple", "nvidia", "amd", or "intel" from the
/// `system_profiler` vendor string, chip type, and model name.
///
/// Ambiguous input defaults to "apple", since that is the only GPU vendor on
/// Apple Silicon Macs.
#[cfg(any(target_os = "macos", test))]
#[cfg_attr(not(target_os = "macos"), allow(dead_code))]
fn classify_mac_gpu_vendor(vendor_str: &str, chip_type: &str, model: &str) -> &'static str {
    let vendor = vendor_str.to_lowercase();
    let chip = chip_type.to_lowercase();
    let model = model.to_lowercase();
    // "ati" must be a whole token: a substring test also fires on "corporation"
    // (e.g. "Intel Corporation", "NVIDIA Corporation").
    let has_token = |text: &str, token: &str| {
        text.split(|c: char| !c.is_ascii_alphanumeric())
            .any(|word| word == token)
    };

    if vendor.contains("apple") || chip.starts_with("apple") || model.contains("apple m") {
        "apple"
    } else if vendor.contains("nvidia") || model.contains("nvidia") || model.contains("geforce") {
        "nvidia"
    } else if vendor.contains("amd")
        || has_token(&vendor, "ati")
        || model.contains("amd")
        || model.contains("radeon")
    {
        "amd"
    } else if vendor.contains("intel") || model.contains("intel") {
        "intel"
    } else {
        // Default to "apple" on macOS when vendor is ambiguous
        "apple"
    }
}
244
/// Total physical memory in MB, read from `sysctl -n hw.memsize` (macOS only)
///
/// On Apple Silicon this is the unified pool shared by CPU and GPU. Returns 0 if
/// `sysctl` cannot be spawned, exits unsuccessfully, or prints something that is
/// not a byte count.
#[cfg(target_os = "macos")]
fn detect_unified_memory_mb() -> u64 {
    const BYTES_PER_MB: u64 = 1024 * 1024;

    let Ok(output) = std::process::Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
    else {
        return 0;
    };
    if !output.status.success() {
        return 0;
    }

    String::from_utf8_lossy(&output.stdout)
        .trim()
        .parse::<u64>()
        .map_or(0, |bytes| bytes / BYTES_PER_MB)
}
262
263// =============================================================================
264// Windows GPU detection
265// =============================================================================
266
/// Detect GPUs on Windows via NVML + WMI + registry
///
/// Runs three layered probes:
/// 1. **NVML** (`nvml-wrapper`): loads `nvml.dll` at runtime; when the NVIDIA
///    driver is absent it returns `Err`, which is treated as "no NVIDIA GPUs".
/// 2. **WMI** (`Win32_VideoController`): enumerates GPU `PnP` devices, keeping only
///    rows whose `PNPDeviceID` matches `PCI\\VEN_%` so Remote Desktop / Hyper-V
///    synthetic adapters are excluded, and extracts the PCI vendor + device IDs.
/// 3. **Registry** (`HKLM\SYSTEM\...\Class\{4d36e968-...}\<xxxx>`): for AMD cards,
///    `HardwareInformation.qwMemorySize` (u64) replaces WMI's `AdapterRAM`, a
///    32-bit field capped at 4 GiB that under-reports modern cards.
///
/// `windows_impl::detect_gpus_windows` merges a WMI row into an existing entry when
/// it identifies the same card instead of listing it twice; for a merged card
/// NVML's VRAM is kept and WMI's name is used only when NVML returned an empty or
/// placeholder name.
#[cfg(target_os = "windows")]
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    self::windows_impl::detect_gpus_windows()
}
287
288// =============================================================================
289// Fallback for unsupported platforms
290// =============================================================================
291
/// GPU detection for platforms without a probe implementation
///
/// Targets other than Linux, macOS, and Windows have no detection backend, so
/// this always reports an empty GPU list.
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    Vec::default()
}
298
299// =============================================================================
300// Linux-only helpers: nvidia-smi, sysfs scanning
301// =============================================================================
302
/// Pre-fetched nvidia-smi data to avoid calling the subprocess multiple times
///
/// Entries are indexed by NVIDIA GPU ordinal as `nvidia-smi` lists them; the
/// sysfs scan looks them up with its per-vendor index. Both vectors are empty when
/// `nvidia-smi` is unavailable or fails.
#[cfg(target_os = "linux")]
struct NvidiaSmiData {
    /// GPU names as reported by `nvidia-smi`, one entry per output line
    names: Vec<String>,
    /// Total GPU memory in MB (`memory.total` with units stripped), one entry per
    /// output line; 0 where a value failed to parse
    memories: Vec<u64>,
}
311
312#[cfg(target_os = "linux")]
313impl NvidiaSmiData {
314    /// Attempt to fetch GPU info from nvidia-smi. Returns empty data on failure.
315    fn fetch() -> Self {
316        let names = Self::query("name");
317        let memories = Self::query("memory.total")
318            .iter()
319            .map(|s| s.trim().parse::<u64>().unwrap_or(0))
320            .collect();
321
322        Self { names, memories }
323    }
324
325    fn query(field: &str) -> Vec<String> {
326        let output = std::process::Command::new("nvidia-smi")
327            .args([
328                &format!("--query-gpu={field}"),
329                "--format=csv,noheader,nounits",
330            ])
331            .output();
332
333        match output {
334            Ok(out) if out.status.success() => {
335                let text = String::from_utf8_lossy(&out.stdout);
336                text.lines().map(|l| l.trim().to_string()).collect()
337            }
338            _ => Vec::new(),
339        }
340    }
341}
342
343// =============================================================================
344// Model detection (Linux only)
345// =============================================================================
346
347/// Read GPU model name from sysfs or nvidia-smi
348#[cfg(target_os = "linux")]
349fn read_gpu_model(
350    device_dir: &std::path::Path,
351    vendor: &str,
352    nvidia_data: &NvidiaSmiData,
353    vendor_index: usize,
354) -> String {
355    // Try DRM subsystem product name first (works for all vendors on recent kernels)
356    if let Some(name) = read_drm_product_name(device_dir) {
357        return name;
358    }
359
360    match vendor {
361        "nvidia" => {
362            // Use pre-fetched nvidia-smi data
363            if let Some(name) = nvidia_data.names.get(vendor_index) {
364                if !name.is_empty() {
365                    return name.clone();
366                }
367            }
368            "NVIDIA GPU".to_string()
369        }
370        "amd" => "AMD GPU".to_string(),
371        "intel" => "Intel GPU".to_string(),
372        _ => "Unknown GPU".to_string(),
373    }
374}
375
/// Try to read the GPU's product name from sysfs
///
/// Checks `<device_dir>/product_name` (published by drivers such as amdgpu) and,
/// as a fallback, `<device_dir>/drm/cardN/device/product_name`. Surrounding
/// whitespace is trimmed and empty files are ignored. Returns `None` when no
/// non-empty name is found.
///
/// The PCI `label` attribute is deliberately not consulted. It holds the
/// firmware-assigned device/slot name (SMBIOS type 41 or ACPI `_DSM`, e.g.
/// "Embedded Video"), not the GPU model, and would shadow the real name that
/// callers would otherwise get from `nvidia-smi`.
#[cfg(target_os = "linux")]
fn read_drm_product_name(device_dir: &std::path::Path) -> Option<String> {
    /// Trimmed contents of `path`, or `None` if unreadable or blank.
    fn non_empty_contents(path: &std::path::Path) -> Option<String> {
        let raw = std::fs::read_to_string(path).ok()?;
        let trimmed = raw.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_string())
    }

    if let Some(name) = non_empty_contents(&device_dir.join("product_name")) {
        return Some(name);
    }

    std::fs::read_dir(device_dir.join("drm"))
        .ok()?
        .flatten()
        .filter(|entry| entry.file_name().to_string_lossy().starts_with("card"))
        .find_map(|entry| non_empty_contents(&entry.path().join("device").join("product_name")))
}
410
411// =============================================================================
412// VRAM detection (Linux only)
413// =============================================================================
414
415/// Read GPU VRAM from sysfs PCI BAR regions or nvidia-smi
416#[cfg(target_os = "linux")]
417fn read_gpu_memory(
418    device_dir: &std::path::Path,
419    vendor: &str,
420    nvidia_data: &NvidiaSmiData,
421    vendor_index: usize,
422) -> u64 {
423    // For NVIDIA, prefer nvidia-smi data (more accurate than PCI BAR)
424    if vendor == "nvidia" {
425        if let Some(&mem) = nvidia_data.memories.get(vendor_index) {
426            if mem > 0 {
427                return mem;
428            }
429        }
430    }
431
432    // For AMD, try the VRAM-specific sysfs file
433    if vendor == "amd" {
434        let vram_path = device_dir.join("mem_info_vram_total");
435        if let Ok(content) = std::fs::read_to_string(&vram_path) {
436            if let Ok(bytes) = content.trim().parse::<u64>() {
437                return bytes / (1024 * 1024);
438            }
439        }
440    }
441
442    // Fall back to reading PCI resource file for BAR sizes
443    // The largest BAR region is typically VRAM
444    let resource_path = device_dir.join("resource");
445    if let Ok(content) = std::fs::read_to_string(&resource_path) {
446        let mut max_size: u64 = 0;
447        for line in content.lines() {
448            let parts: Vec<&str> = line.split_whitespace().collect();
449            if parts.len() >= 2 {
450                if let (Ok(start), Ok(end)) = (
451                    u64::from_str_radix(parts[0].trim_start_matches("0x"), 16),
452                    u64::from_str_radix(parts[1].trim_start_matches("0x"), 16),
453                ) {
454                    if end > start {
455                        let size = end - start + 1;
456                        if size > max_size {
457                            max_size = size;
458                        }
459                    }
460                }
461            }
462        }
463        if max_size > 0 {
464            return max_size / (1024 * 1024);
465        }
466    }
467
468    0
469}
470
471// =============================================================================
472// Device path resolution (Linux only)
473// =============================================================================
474
/// Find device paths for a GPU (Linux only)
///
/// - NVIDIA: `/dev/nvidia<minor>`, with the minor read from the proprietary
///   driver's `/proc/driver/nvidia/gpus/<pci_bus_id>/information`
///   (`Device Minor:` line). No render node is reported.
/// - Everything else: the `cardN` / `renderDN` nodes listed under
///   `/sys/bus/pci/devices/<pci_bus_id>/drm`. DRM minors are allocated across
///   all vendors, so e.g. an AMD card next to an Intel iGPU need not be `card0`.
///
/// When those kernel interfaces are unavailable, falls back to the per-vendor
/// index (`/dev/nvidia<i>`, `/dev/dri/card<i>`, `/dev/dri/renderD<128+i>`).
#[cfg(target_os = "linux")]
fn find_device_paths(
    pci_bus_id: &str,
    vendor: &str,
    vendor_index: usize,
) -> (String, Option<String>) {
    if vendor == "nvidia" {
        let info_path = format!("/proc/driver/nvidia/gpus/{pci_bus_id}/information");
        let minor = std::fs::read_to_string(info_path)
            .ok()
            .and_then(|info| parse_nvidia_device_minor(&info))
            .unwrap_or(vendor_index);
        return (format!("/dev/nvidia{minor}"), None);
    }

    // AMD, Intel, and unknown vendors use DRI device nodes
    let pci_devices = std::path::Path::new("/sys/bus/pci/devices");
    match drm_node_names(pci_devices, pci_bus_id) {
        Some((card, render)) => (
            format!("/dev/dri/{card}"),
            render.map(|node| format!("/dev/dri/{node}")),
        ),
        None => (
            format!("/dev/dri/card{vendor_index}"),
            Some(format!("/dev/dri/renderD{}", 128 + vendor_index)),
        ),
    }
}

/// Extract the device minor from the NVIDIA driver's per-GPU `information` file.
#[cfg(target_os = "linux")]
fn parse_nvidia_device_minor(information: &str) -> Option<usize> {
    information
        .lines()
        .find_map(|line| line.strip_prefix("Device Minor:"))
        .and_then(|value| value.trim().parse().ok())
}

/// Names of the `cardN` and `renderDN` nodes listed in
/// `<pci_devices_dir>/<pci_bus_id>/drm`.
///
/// Returns `None` when the directory is unreadable or contains no `cardN` entry.
/// The render node is `None` if the device has no `renderDN` entry.
#[cfg(target_os = "linux")]
fn drm_node_names(
    pci_devices_dir: &std::path::Path,
    pci_bus_id: &str,
) -> Option<(String, Option<String>)> {
    /// `name` is `prefix` followed by one or more ASCII digits.
    fn is_numbered(name: &str, prefix: &str) -> bool {
        match name.strip_prefix(prefix) {
            Some(digits) => !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit()),
            None => false,
        }
    }

    let mut card = None;
    let mut render = None;
    for entry in std::fs::read_dir(pci_devices_dir.join(pci_bus_id).join("drm"))
        .ok()?
        .flatten()
    {
        let name = entry.file_name().to_string_lossy().into_owned();
        if card.is_none() && is_numbered(&name, "card") {
            card = Some(name);
        } else if render.is_none() && is_numbered(&name, "renderD") {
            render = Some(name);
        }
    }
    card.map(|card| (card, render))
}
492
493// =============================================================================
494// Windows implementation module
495// =============================================================================
496
497#[cfg(target_os = "windows")]
498mod windows_impl {
499    use super::GpuInfo;
500    use std::collections::HashMap;
501    use wmi::{Variant, WMIConnection};
502
    /// Display-adapter device-setup class GUID. Every `DISPLAY` driver instance
    /// registers itself under `HKLM\SYSTEM\CurrentControlSet\Control\Class\{4d36e968-...}`
    /// with a 4-digit index (`0000`, `0001`, ...). AMD's driver writes
    /// `HardwareInformation.qwMemorySize` (`REG_QWORD`, bytes) there — a value whose
    /// name contains a literal dot, stored directly on the instance key. This is the
    /// authoritative VRAM size, not WMI's 32-bit-capped `AdapterRAM`.
    const DISPLAY_CLASS_GUID: &str = "{4d36e968-e325-11ce-bfc1-08002be10318}";
509
510    /// Entry point for the `windows` target's `detect_gpus()`.
511    pub fn detect_gpus_windows() -> Vec<GpuInfo> {
512        let mut gpus: Vec<GpuInfo> = Vec::new();
513
514        // --- Pass 1: NVML (best-effort; absent driver is not fatal) ---------
515        let nvml_gpus = detect_via_nvml();
516        gpus.extend(nvml_gpus);
517
518        // --- Pass 2: WMI Win32_VideoController ------------------------------
519        let wmi_gpus = match detect_via_wmi() {
520            Ok(v) => v,
521            Err(e) => {
522                tracing::warn!(
523                    "WMI Win32_VideoController query failed: {e}; \
524                     Windows GPU detection falling back to NVML-only"
525                );
526                Vec::new()
527            }
528        };
529
530        // --- Pass 3: Registry VRAM correction for AMD, and dedupe ----------
531        let amd_registry = collect_amd_registry_vram().unwrap_or_default();
532
533        for mut wmi_gpu in wmi_gpus {
534            // AMD: replace WMI's AdapterRAM with qwMemorySize if we have it.
535            if wmi_gpu.vendor == "amd" {
536                if let Some(key) = pci_key_from_bus_id(&wmi_gpu.pci_bus_id) {
537                    if let Some(&vram_bytes) = amd_registry.get(&key) {
538                        wmi_gpu.memory_mb = vram_bytes / (1024 * 1024);
539                    }
540                }
541            }
542
543            // Dedupe against NVML entries (same PCI bus id).
544            if let Some(existing) = gpus.iter_mut().find(|g| g.pci_bus_id == wmi_gpu.pci_bus_id) {
545                // NVML wins for VRAM (accurate), but let WMI enrich an empty
546                // / placeholder NVML model name.
547                if existing.model.trim().is_empty() || existing.model == "NVIDIA GPU" {
548                    wmi_gpu.model.clone_into(&mut existing.model);
549                }
550                continue;
551            }
552
553            gpus.push(wmi_gpu);
554        }
555
556        gpus
557    }
558
559    // ------------------------------------------------------------------------
560    // NVML probe
561    // ------------------------------------------------------------------------
562
563    fn detect_via_nvml() -> Vec<GpuInfo> {
564        // `Nvml::init()` dlopens `nvml.dll`. When the NVIDIA driver isn't
565        // installed, this returns a clean `Err` — treat that as "no NVIDIA
566        // GPUs" and move on without polluting logs.
567        let nvml = match nvml_wrapper::Nvml::init() {
568            Ok(n) => n,
569            Err(e) => {
570                tracing::debug!("NVML unavailable (no NVIDIA driver?): {e}");
571                return Vec::new();
572            }
573        };
574
575        let count = match nvml.device_count() {
576            Ok(c) => c,
577            Err(e) => {
578                tracing::debug!("nvmlDeviceGetCount failed: {e}");
579                return Vec::new();
580            }
581        };
582
583        let mut out = Vec::with_capacity(count as usize);
584        for i in 0..count {
585            let device = match nvml.device_by_index(i) {
586                Ok(d) => d,
587                Err(e) => {
588                    tracing::debug!("nvmlDeviceGetHandleByIndex({i}) failed: {e}");
589                    continue;
590                }
591            };
592
593            let model = device.name().unwrap_or_else(|_| "NVIDIA GPU".to_string());
594
595            // `memory_info()` returns bytes.
596            let memory_mb = device
597                .memory_info()
598                .map(|m| m.total / (1024 * 1024))
599                .unwrap_or(0);
600
601            let pci_bus_id = match device.pci_info() {
602                Ok(info) => canonicalize_pci_bus_id(&info.bus_id),
603                Err(e) => {
604                    tracing::debug!("nvmlDeviceGetPciInfo({i}) failed: {e}");
605                    // Without a PCI bus id we can't dedupe against WMI, so
606                    // fall back to an index-based id.
607                    format!("nvml:{i}")
608                }
609            };
610
611            out.push(GpuInfo {
612                pci_bus_id,
613                vendor: "nvidia".to_string(),
614                model,
615                memory_mb,
616                // Windows doesn't expose /dev nodes — report the NVML handle
617                // index so downstream consumers can map back to the GPU.
618                device_path: format!("nvml://{i}"),
619                render_path: None,
620            });
621        }
622
623        out
624    }
625
626    /// NVML returns a PCI bus ID like `00000000:01:00.0` (8-char domain).
627    /// sysfs / Win32 tooling conventionally uses 4-char domain (`0000:01:00.0`),
628    /// so trim the leading zeros in the domain segment to keep our bus IDs
629    /// consistent across probes (NVML vs WMI registry parsing).
630    fn canonicalize_pci_bus_id(raw: &str) -> String {
631        let mut parts = raw.splitn(2, ':');
632        let (Some(domain), Some(rest)) = (parts.next(), parts.next()) else {
633            return raw.to_ascii_lowercase();
634        };
635        let trimmed_domain = domain.trim_start_matches('0').to_string();
636        let domain_out = if trimmed_domain.is_empty() {
637            "0000".to_string()
638        } else if trimmed_domain.len() < 4 {
639            format!("{trimmed_domain:0>4}")
640        } else {
641            trimmed_domain
642        };
643        format!("{domain_out}:{rest}").to_ascii_lowercase()
644    }
645
646    // ------------------------------------------------------------------------
647    // WMI probe
648    // ------------------------------------------------------------------------
649
650    fn detect_via_wmi() -> Result<Vec<GpuInfo>, String> {
651        // `wmi::WMIConnection::new` defaults to the `ROOT\CIMV2` namespace and
652        // auto-initializes COM via `CoIncrementMTAUsage` when COM hasn't been
653        // initialized on the current thread. No explicit `COMLibrary` needed
654        // in wmi 0.18 — the reffcount is tracked internally.
655        let wmi = WMIConnection::new().map_err(|e| format!("WMIConnection::new: {e}"))?;
656
657        // `PNPDeviceID LIKE 'PCI\\VEN_%'` filters out Remote Desktop Mirror
658        // drivers, Hyper-V synthetic adapters, and software renderers — only
659        // real PCI display controllers match.
660        let query = "SELECT Name, PNPDeviceID, AdapterRAM \
661                     FROM Win32_VideoController \
662                     WHERE PNPDeviceID LIKE 'PCI\\\\VEN_%'";
663
664        let rows: Vec<HashMap<String, Variant>> = wmi
665            .raw_query(query)
666            .map_err(|e| format!("raw_query({query}): {e}"))?;
667
668        let mut out = Vec::with_capacity(rows.len());
669        for row in rows {
670            let Some(pnp) = variant_string(row.get("PNPDeviceID")) else {
671                continue;
672            };
673            let Some((vendor_id, device_id)) = parse_ven_dev(&pnp) else {
674                continue;
675            };
676
677            let vendor = vendor_from_pci_id(vendor_id);
678            let model = variant_string(row.get("Name"))
679                .unwrap_or_else(|| format!("{} GPU", vendor.to_ascii_uppercase()));
680
681            // AdapterRAM is a uint32 in bytes (CIM_UINT32), so WMI surfaces it
682            // as Variant::UI4 / I4. Cap-at-4-GiB is the known lie we correct
683            // below for AMD via the registry.
684            let memory_mb = variant_u64(row.get("AdapterRAM")).map_or(0, |b| b / (1024 * 1024));
685
686            let pci_bus_id = pci_bus_id_from_pnp(&pnp, vendor_id, device_id);
687
688            out.push(GpuInfo {
689                pci_bus_id,
690                vendor: vendor.to_string(),
691                model,
692                memory_mb,
693                device_path: format!(r"\\.\DISPLAY#{pnp}"),
694                render_path: None,
695            });
696        }
697
698        Ok(out)
699    }
700
701    /// Extract `(vendor_id, device_id)` from a `PNPDeviceID` like
702    /// `PCI\VEN_10DE&DEV_2204&SUBSYS_...&REV_A1\4&31DE5EF7&0&0008`.
703    fn parse_ven_dev(pnp: &str) -> Option<(u16, u16)> {
704        // Matches case-insensitively. The PnP id format is stable across
705        // Windows versions since XP.
706        let upper = pnp.to_ascii_uppercase();
707        let ven = extract_hex(&upper, "VEN_", 4)?;
708        let dev = extract_hex(&upper, "DEV_", 4)?;
709        Some((ven, dev))
710    }
711
712    fn extract_hex(s: &str, marker: &str, nibbles: usize) -> Option<u16> {
713        let start = s.find(marker)? + marker.len();
714        let hex = s.get(start..start + nibbles)?;
715        u16::from_str_radix(hex, 16).ok()
716    }
717
718    fn vendor_from_pci_id(vendor_id: u16) -> &'static str {
719        match vendor_id {
720            0x10DE => "nvidia",
721            // 0x1002 = AMD discrete (ATI); 0x1022 = AMD APU IGP.
722            0x1002 | 0x1022 => "amd",
723            0x8086 => "intel",
724            _ => "unknown",
725        }
726    }
727
728    /// Best-effort synthetic PCI bus ID derived from the `PnP` path. Windows
729    /// doesn't hand us the bus:device.function triple directly in the `PnP`
730    /// string, so we fall back to `0000:<ven>:<dev>.0` — consistent across
731    /// NVML-vs-WMI dedupe runs because NVML's `bus_id` is different shape.
732    ///
733    /// NVML's `pci_info().bus_id` has the real bus:device.function, so for
734    /// NVIDIA GPUs present in both probes we always prefer the NVML entry
735    /// (dedup runs in `detect_gpus_windows` by NVML-side bus id). WMI-only
736    /// entries (AMD, Intel) use this synthetic form, which remains stable per
737    /// host for registry matching.
738    fn pci_bus_id_from_pnp(pnp: &str, vendor_id: u16, device_id: u16) -> String {
739        // Try to pull the instance-id tail (`&0008` above → `0008`) as a weak
740        // "slot" signal so multiple same-vendor cards on one host get distinct
741        // ids. Falls back to `0000`.
742        let slot = pnp
743            .rsplit_once('&')
744            .and_then(|(_, tail)| tail.chars().take(4).collect::<String>().parse::<u16>().ok())
745            .unwrap_or(0);
746        format!("0000:{vendor_id:04x}:{device_id:04x}.{slot:x}")
747    }
748
749    fn variant_string(v: Option<&Variant>) -> Option<String> {
750        match v? {
751            Variant::String(s) => Some(s.clone()),
752            _ => None,
753        }
754    }
755
756    fn variant_u64(v: Option<&Variant>) -> Option<u64> {
757        match v? {
758            Variant::UI1(n) => Some(u64::from(*n)),
759            Variant::UI2(n) => Some(u64::from(*n)),
760            Variant::UI4(n) => Some(u64::from(*n)),
761            Variant::UI8(n) => Some(*n),
762            Variant::I1(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
763            Variant::I2(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
764            Variant::I4(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
765            Variant::I8(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
766            _ => None,
767        }
768    }
769
770    // ------------------------------------------------------------------------
771    // Registry probe (AMD VRAM correction)
772    // ------------------------------------------------------------------------
773
774    /// Walk `HKLM\SYSTEM\CurrentControlSet\Control\Class\{4d36e968-...}` for
775    /// AMD display-adapter subkeys and collect their `qwMemorySize` values.
776    /// Returns a map from `(vendor_id, device_id)` → VRAM-in-bytes.
777    fn collect_amd_registry_vram() -> Result<HashMap<(u16, u16), u64>, String> {
778        let class_path = format!(r"SYSTEM\CurrentControlSet\Control\Class\{DISPLAY_CLASS_GUID}");
779        let class_key = windows_registry::LOCAL_MACHINE
780            .open(&class_path)
781            .map_err(|e| format!("open HKLM\\{class_path}: {e}"))?;
782
783        let mut out: HashMap<(u16, u16), u64> = HashMap::new();
784
785        // Subkeys are 4-digit instance indices: 0000, 0001, ...
786        // Non-4-digit subkeys (Properties, Configuration) are harmless to
787        // attempt — they just fail the open/value reads and we move on.
788        let subkeys = match class_key.keys() {
789            Ok(it) => it,
790            Err(e) => return Err(format!("enumerate class subkeys: {e}")),
791        };
792
793        for name in subkeys {
794            // Skip obvious non-adapter subkeys quickly; still allow anything
795            // that parses as a 4-digit hex-ish index.
796            if name.len() != 4 || !name.chars().all(|c| c.is_ascii_digit()) {
797                continue;
798            }
799
800            let Ok(adapter_key) = class_key.open(&name) else {
801                continue;
802            };
803
804            let vendor_id = adapter_key
805                .get_string("MatchingDeviceId")
806                .ok()
807                .as_deref()
808                .and_then(parse_matching_device_id);
809            let Some((ven, dev)) = vendor_id else {
810                continue;
811            };
812
813            if vendor_from_pci_id(ven) != "amd" {
814                continue;
815            }
816
817            // qwMemorySize is a REG_QWORD written by modern AMD drivers.
818            // Older drivers may only publish `MemorySize` (REG_DWORD) — but
819            // that hits the same 32-bit-cap problem as WMI, so we skip it
820            // rather than report a wrong number.
821            if let Ok(bytes) = adapter_key
822                .open("HardwareInformation")
823                .and_then(|hw| hw.get_u64("qwMemorySize"))
824            {
825                out.insert((ven, dev), bytes);
826            }
827        }
828
829        Ok(out)
830    }
831
832    /// `MatchingDeviceId` is formatted like `PCI\VEN_1002&DEV_73A5&SUBSYS_...`.
833    /// Reuse the `PnP` parser.
834    fn parse_matching_device_id(s: &str) -> Option<(u16, u16)> {
835        parse_ven_dev(s)
836    }
837
838    /// Build the AMD-registry lookup key from a WMI bus id (`0000:1002:73a5.0`).
839    fn pci_key_from_bus_id(bus_id: &str) -> Option<(u16, u16)> {
840        let mut parts = bus_id.split(':');
841        let _domain = parts.next()?;
842        let ven = u16::from_str_radix(parts.next()?, 16).ok()?;
843        let dev_fn = parts.next()?;
844        let dev = dev_fn
845            .split('.')
846            .next()
847            .and_then(|h| u16::from_str_radix(h, 16).ok())?;
848        Some((ven, dev))
849    }
850
851    // ------------------------------------------------------------------------
852    // Unit tests for pure helpers (no Windows APIs touched)
853    // ------------------------------------------------------------------------
854
855    #[cfg(test)]
856    mod tests {
857        use super::*;
858
859        #[test]
860        fn parse_ven_dev_nvidia() {
861            let pnp = r"PCI\VEN_10DE&DEV_2204&SUBSYS_38811462&REV_A1\4&31DE5EF7&0&0008";
862            assert_eq!(parse_ven_dev(pnp), Some((0x10DE, 0x2204)));
863        }
864
865        #[test]
866        fn parse_ven_dev_amd() {
867            let pnp = r"PCI\VEN_1002&DEV_73A5&SUBSYS_E4571DA2&REV_C0\4&1A2B3C4D&0&0010";
868            assert_eq!(parse_ven_dev(pnp), Some((0x1002, 0x73A5)));
869        }
870
871        #[test]
872        fn parse_ven_dev_intel_lowercase() {
873            let pnp = r"pci\ven_8086&dev_9a49&subsys_00000000&rev_01\3&11583659&0&10";
874            assert_eq!(parse_ven_dev(pnp), Some((0x8086, 0x9A49)));
875        }
876
877        #[test]
878        fn parse_ven_dev_rejects_malformed() {
879            assert_eq!(parse_ven_dev("USB\\VID_1234&PID_5678"), None);
880            assert_eq!(parse_ven_dev(""), None);
881        }
882
883        #[test]
884        fn vendor_id_mapping() {
885            assert_eq!(vendor_from_pci_id(0x10DE), "nvidia");
886            assert_eq!(vendor_from_pci_id(0x1002), "amd");
887            assert_eq!(vendor_from_pci_id(0x1022), "amd");
888            assert_eq!(vendor_from_pci_id(0x8086), "intel");
889            assert_eq!(vendor_from_pci_id(0x1234), "unknown");
890        }
891
892        #[test]
893        fn canonicalize_strips_nvml_domain_padding() {
894            assert_eq!(canonicalize_pci_bus_id("00000000:01:00.0"), "0000:01:00.0");
895            assert_eq!(canonicalize_pci_bus_id("0000:17:00.0"), "0000:17:00.0");
896            assert_eq!(canonicalize_pci_bus_id("000a:17:00.0"), "000a:17:00.0");
897        }
898
899        #[test]
900        fn canonicalize_handles_missing_colon() {
901            // No colon -> lower-case passthrough (defensive).
902            assert_eq!(canonicalize_pci_bus_id("WEIRD"), "weird");
903        }
904
905        #[test]
906        fn pci_key_from_bus_id_roundtrip() {
907            let bus = pci_bus_id_from_pnp(
908                r"PCI\VEN_1002&DEV_73A5&SUBSYS_E4571DA2&REV_C0\4&1A2B3C4D&0&0010",
909                0x1002,
910                0x73A5,
911            );
912            assert_eq!(pci_key_from_bus_id(&bus), Some((0x1002, 0x73A5)));
913        }
914
915        #[test]
916        fn variant_u64_accepts_unsigned_widths() {
917            assert_eq!(variant_u64(Some(&Variant::UI4(4096))), Some(4096));
918            assert_eq!(
919                variant_u64(Some(&Variant::UI8(17_179_869_184))),
920                Some(17_179_869_184)
921            );
922            assert_eq!(variant_u64(Some(&Variant::UI1(7))), Some(7));
923        }
924
925        #[test]
926        fn variant_u64_rejects_negative_signed() {
927            assert_eq!(variant_u64(Some(&Variant::I4(-1))), None);
928        }
929
930        #[test]
931        fn variant_string_unwraps() {
932            assert_eq!(
933                variant_string(Some(&Variant::String("NVIDIA GeForce RTX 4090".into()))),
934                Some("NVIDIA GeForce RTX 4090".to_string())
935            );
936            assert_eq!(variant_string(Some(&Variant::UI4(7))), None);
937            assert_eq!(variant_string(None), None);
938        }
939    }
940}
941
942// =============================================================================
943// Tests
944// =============================================================================
945
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize `info` to JSON and back, asserting nothing is lost on the way.
    fn assert_json_roundtrip(info: &GpuInfo) {
        let encoded = serde_json::to_string(info).unwrap();
        let decoded: GpuInfo = serde_json::from_str(&encoded).unwrap();
        assert_eq!(*info, decoded);
    }

    #[test]
    fn test_gpu_info_serialization_roundtrip() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "0000:01:00.0".to_string(),
            vendor: "nvidia".to_string(),
            model: "NVIDIA A100-SXM4-80GB".to_string(),
            memory_mb: 81920,
            device_path: "/dev/nvidia0".to_string(),
            render_path: None,
        });
    }

    #[test]
    fn test_gpu_info_amd_serialization() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "0000:03:00.0".to_string(),
            vendor: "amd".to_string(),
            model: "AMD GPU".to_string(),
            memory_mb: 16384,
            device_path: "/dev/dri/card0".to_string(),
            render_path: Some("/dev/dri/renderD128".to_string()),
        });
    }

    #[test]
    fn test_gpu_info_apple_serialization() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "apple:0".to_string(),
            vendor: "apple".to_string(),
            model: "Apple M2 Pro".to_string(),
            memory_mb: 32768,
            device_path: "iokit://AppleGPU/0".to_string(),
            render_path: None,
        });
    }

    /// The first DRM-backed GPU (index 0) maps to `card0` plus `renderD128`.
    #[cfg(target_os = "linux")]
    fn assert_first_dri_node(bus_id: &str, vendor: &str) {
        let (dev, render) = find_device_paths(bus_id, vendor, 0);
        assert_eq!(dev, "/dev/dri/card0");
        assert_eq!(render.as_deref(), Some("/dev/dri/renderD128"));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_nvidia() {
        let cases = [
            ("0000:01:00.0", 0, "/dev/nvidia0"),
            ("0000:02:00.0", 1, "/dev/nvidia1"),
        ];
        for (bus_id, index, expected_dev) in cases {
            let (dev, render) = find_device_paths(bus_id, "nvidia", index);
            assert_eq!(dev, expected_dev);
            assert!(render.is_none());
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_amd() {
        assert_first_dri_node("0000:03:00.0", "amd");
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_intel() {
        assert_first_dri_node("0000:00:02.0", "intel");
    }

    #[test]
    fn test_detect_gpus_returns_vec() {
        // GPU-less CI/dev machines yield an empty list; machines with GPUs
        // must report every identifying field as non-empty.
        for gpu in detect_gpus() {
            let required = [&gpu.pci_bus_id, &gpu.vendor, &gpu.model, &gpu.device_path];
            assert!(required.iter().all(|field| !field.is_empty()));
        }
    }
}