Skip to main content

zlayer_agent/
gpu_detector.rs

1//! GPU inventory detection
2//!
3//! Platform-specific GPU detection:
4//! - **Linux**: Scans `/sys/bus/pci/devices` for display controllers (VGA and 3D controllers).
5//!   Identifies vendor (NVIDIA, AMD, Intel) by PCI vendor ID, reads VRAM from PCI BAR regions,
6//!   and optionally uses `nvidia-smi` for NVIDIA-specific model and memory information.
7//! - **macOS**: Uses `system_profiler SPDisplaysDataType -json` to detect Apple Silicon GPUs
8//!   and unified memory via `sysctl -n hw.memsize`.
9//! - **Other**: Returns an empty GPU list.
10//!
11//! No external dependencies required -- pure `sysfs/system_profiler` scanning with optional
12//! subprocess calls for enrichment.
13
14use serde::{Deserialize, Serialize};
15
16/// Detected GPU information
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct GpuInfo {
19    /// PCI bus ID (e.g., "0000:01:00.0" on Linux, "apple:0" on macOS)
20    pub pci_bus_id: String,
21    /// Vendor: "nvidia", "amd", "intel", "apple", or "unknown"
22    pub vendor: String,
23    /// Model name (e.g., "Apple M2 Pro" or "NVIDIA A100-SXM4-80GB")
24    pub model: String,
25    /// VRAM in MB (0 if unknown; on Apple Silicon, this is unified memory)
26    pub memory_mb: u64,
27    /// Device path (e.g., "/dev/nvidia0", "/dev/dri/card0", "<iokit://AppleGPU/0>")
28    pub device_path: String,
29    /// Render node path if applicable (e.g., "/dev/dri/renderD128"); None on macOS
30    pub render_path: Option<String>,
31}
32
33// =============================================================================
34// Linux GPU detection
35// =============================================================================
36
37/// Scan the system for GPU devices via sysfs PCI enumeration (Linux only)
38///
39/// Iterates over `/sys/bus/pci/devices` looking for PCI class codes that
40/// indicate display controllers:
41/// - `0x0300xx` -- VGA compatible controller
42/// - `0x0302xx` -- 3D controller (e.g., NVIDIA Tesla/datacenter GPUs)
43///
44/// For each GPU found, determines vendor, model name, VRAM, and device paths.
45#[cfg(target_os = "linux")]
46#[must_use]
47pub fn detect_gpus() -> Vec<GpuInfo> {
48    use std::path::Path;
49
50    let mut gpus = Vec::new();
51
52    let pci_dir = Path::new("/sys/bus/pci/devices");
53    if !pci_dir.exists() {
54        return gpus;
55    }
56
57    let Ok(entries) = std::fs::read_dir(pci_dir) else {
58        return gpus;
59    };
60
61    // Optionally pre-fetch nvidia-smi data once for all NVIDIA GPUs
62    let nvidia_data = NvidiaSmiData::fetch();
63
64    for entry in entries.flatten() {
65        let device_dir = entry.path();
66
67        // Read PCI device class
68        let class_path = device_dir.join("class");
69        let class = match std::fs::read_to_string(&class_path) {
70            Ok(c) => c.trim().to_string(),
71            Err(_) => continue,
72        };
73
74        // Filter to display controllers only
75        if !class.starts_with("0x0302") && !class.starts_with("0x0300") {
76            continue;
77        }
78
79        // Read PCI vendor ID
80        let vendor_path = device_dir.join("vendor");
81        let vendor_id = std::fs::read_to_string(&vendor_path)
82            .unwrap_or_default()
83            .trim()
84            .to_string();
85
86        let vendor = match vendor_id.as_str() {
87            "0x10de" => "nvidia",
88            "0x1002" => "amd",
89            "0x8086" => "intel",
90            _ => "unknown",
91        }
92        .to_string();
93
94        let pci_bus_id = entry.file_name().to_string_lossy().to_string();
95
96        // Count how many GPUs of this vendor we've already seen (for device path indexing)
97        let vendor_index = gpus
98            .iter()
99            .filter(|g: &&GpuInfo| g.vendor == vendor)
100            .count();
101
102        let model = read_gpu_model(&device_dir, &vendor, &nvidia_data, vendor_index);
103        let memory_mb = read_gpu_memory(&device_dir, &vendor, &nvidia_data, vendor_index);
104        let (device_path, render_path) = find_device_paths(&pci_bus_id, &vendor, vendor_index);
105
106        gpus.push(GpuInfo {
107            pci_bus_id,
108            vendor,
109            model,
110            memory_mb,
111            device_path,
112            render_path,
113        });
114    }
115
116    gpus
117}
118
119// =============================================================================
120// macOS GPU detection
121// =============================================================================
122
/// Enumerate the GPUs of a macOS host (macOS only)
///
/// Public entry point for macOS builds. All of the work is delegated to
/// `detect_apple_gpus`, whose result is returned unchanged.
#[cfg(target_os = "macos")]
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    detect_apple_gpus()
}
134
/// Internal macOS GPU detection implementation
///
/// Runs `system_profiler SPDisplaysDataType -json` and converts every entry of
/// the top-level `SPDisplaysDataType` array into a [`GpuInfo`]. Returns an empty
/// list when the command fails, its output is not valid JSON, or the expected
/// array is missing.
///
/// Memory is taken from a VRAM field of the entry when one parses; otherwise
/// the unified memory size from `detect_unified_memory_mb` is reported.
#[cfg(target_os = "macos")]
fn detect_apple_gpus() -> Vec<GpuInfo> {
    let output = match std::process::Command::new("system_profiler")
        .args(["SPDisplaysDataType", "-json"])
        .output()
    {
        Ok(out) if out.status.success() => out,
        _ => return Vec::new(),
    };

    let json_str = String::from_utf8_lossy(&output.stdout);
    let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_str) else {
        return Vec::new();
    };

    // system_profiler returns { "SPDisplaysDataType": [ { ... }, ... ] }
    let Some(displays) = parsed.get("SPDisplaysDataType").and_then(|v| v.as_array()) else {
        return Vec::new();
    };

    // Queried only once there is something to report, to avoid a needless subprocess.
    let unified_memory_mb = detect_unified_memory_mb();

    let mut gpus = Vec::with_capacity(displays.len());

    for (idx, display) in displays.iter().enumerate() {
        let model = first_str_field(display, &["sppci_model", "_name"]).unwrap_or("Apple GPU");
        let chip_type = first_str_field(display, &["sppci_chiptype"]).unwrap_or("");

        // Use chip type in model name if the model doesn't already include it
        let model = if !chip_type.is_empty() && !model.contains(chip_type) {
            format!("{model} ({chip_type})")
        } else {
            model.to_string()
        };

        // Apple Silicon uses unified memory -- report the full system memory
        // as GPU-accessible memory unless the entry carries its own VRAM size.
        // NOTE(review): the `spdisplays_*` key names are believed to be what
        // system_profiler emits for discrete/integrated GPUs -- confirm on hardware.
        let memory_mb = ["sppci_vram", "spdisplays_vram", "spdisplays_vram_shared"]
            .iter()
            .find_map(|key| {
                display
                    .get(*key)
                    .and_then(|v| v.as_str())
                    .and_then(parse_vram_mb)
            })
            .unwrap_or(unified_memory_mb);

        // NOTE(review): `spdisplays_vendor` is a fallback key -- confirm on hardware.
        let vendor_str =
            first_str_field(display, &["sppci_vendor", "spdisplays_vendor"]).unwrap_or("");
        let vendor = classify_macos_gpu_vendor(vendor_str, chip_type, &model).to_string();

        gpus.push(GpuInfo {
            pci_bus_id: format!("apple:{idx}"),
            vendor,
            model,
            memory_mb,
            device_path: format!("iokit://AppleGPU/{idx}"),
            render_path: None,
        });
    }

    gpus
}

/// Return the string value of the first key in `keys` that `entry` holds as a string.
#[cfg(target_os = "macos")]
fn first_str_field<'a>(entry: &'a serde_json::Value, keys: &[&str]) -> Option<&'a str> {
    keys.iter()
        .find_map(|key| entry.get(*key).and_then(|v| v.as_str()))
}

/// Parse a VRAM size such as `"16 GB"` or `"8192 MB"` into megabytes.
///
/// Returns `None` for any other shape (missing unit, unknown unit, non-integer
/// amount) and when the GB-to-MB conversion would overflow `u64`.
#[cfg(target_os = "macos")]
fn parse_vram_mb(text: &str) -> Option<u64> {
    let mut parts = text.split_whitespace();
    let amount: u64 = parts.next()?.parse().ok()?;
    match parts.next()?.to_uppercase().as_str() {
        "GB" => amount.checked_mul(1024),
        "MB" => Some(amount),
        _ => None,
    }
}

/// Map system_profiler vendor/chip/model strings to a lowercase vendor tag.
///
/// "ati" is matched as a whole word only: a plain substring test would also hit
/// names such as "Intel Corporation" or "NVIDIA Corporation" and label them AMD.
/// Anything unrecognised defaults to `"apple"`.
#[cfg(target_os = "macos")]
fn classify_macos_gpu_vendor(vendor_str: &str, chip_type: &str, model: &str) -> &'static str {
    let vendor_lc = vendor_str.to_lowercase();
    let has_ati_word = vendor_lc
        .split(|c: char| !c.is_ascii_alphanumeric())
        .any(|word| word == "ati");

    if vendor_lc.contains("apple")
        || chip_type.to_lowercase().starts_with("apple")
        || model.to_lowercase().contains("apple m")
    {
        "apple"
    } else if vendor_lc.contains("amd") || has_ati_word {
        "amd"
    } else if vendor_lc.contains("intel") {
        "intel"
    } else if vendor_lc.contains("nvidia") {
        "nvidia"
    } else {
        // Default to "apple" on macOS when vendor is ambiguous
        "apple"
    }
}
237
/// Report the host's physical memory in MB, as printed by `sysctl -n hw.memsize`
///
/// Yields 0 when the command cannot be run, exits unsuccessfully, or prints
/// something that does not parse as an unsigned integer byte count.
#[cfg(target_os = "macos")]
fn detect_unified_memory_mb() -> u64 {
    const BYTES_PER_MB: u64 = 1024 * 1024;

    std::process::Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
        .ok()
        .filter(|out| out.status.success())
        .and_then(|out| {
            String::from_utf8_lossy(&out.stdout)
                .trim()
                .parse::<u64>()
                .ok()
        })
        .map_or(0, |bytes| bytes / BYTES_PER_MB)
}
255
256// =============================================================================
257// Fallback for unsupported platforms
258// =============================================================================
259
/// GPU detection stub for platforms other than Linux and macOS
///
/// No detection backend exists for these targets, so the result is always an
/// empty list.
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    Vec::new()
}
266
267// =============================================================================
268// Linux-only helpers: nvidia-smi, sysfs scanning
269// =============================================================================
270
/// Snapshot of `nvidia-smi` query results, captured once per detection pass so
/// the subprocess is not spawned again for every GPU
#[cfg(target_os = "linux")]
struct NvidiaSmiData {
    /// One GPU name per output line of the `name` query
    names: Vec<String>,
    /// One memory size in MB per output line of the `memory.total` query
    /// (0 where a line did not parse as an integer)
    memories: Vec<u64>,
}

#[cfg(target_os = "linux")]
impl NvidiaSmiData {
    /// Run the `name` and `memory.total` queries and bundle their results.
    ///
    /// Never fails: a query that cannot be run contributes an empty list.
    fn fetch() -> Self {
        let names = Self::query("name");
        let memories = Self::query("memory.total")
            .into_iter()
            .map(|raw| raw.trim().parse::<u64>().unwrap_or(0))
            .collect();

        Self { names, memories }
    }

    /// Run `nvidia-smi --query-gpu=<field> --format=csv,noheader,nounits` and
    /// return its stdout as trimmed lines.
    ///
    /// Returns an empty list when the command cannot be spawned or exits with
    /// a non-success status.
    fn query(field: &str) -> Vec<String> {
        let selector = format!("--query-gpu={field}");

        let Ok(out) = std::process::Command::new("nvidia-smi")
            .args([selector.as_str(), "--format=csv,noheader,nounits"])
            .output()
        else {
            return Vec::new();
        };

        if !out.status.success() {
            return Vec::new();
        }

        String::from_utf8_lossy(&out.stdout)
            .lines()
            .map(|line| line.trim().to_owned())
            .collect()
    }
}
310
311// =============================================================================
312// Model detection (Linux only)
313// =============================================================================
314
315/// Read GPU model name from sysfs or nvidia-smi
316#[cfg(target_os = "linux")]
317fn read_gpu_model(
318    device_dir: &std::path::Path,
319    vendor: &str,
320    nvidia_data: &NvidiaSmiData,
321    vendor_index: usize,
322) -> String {
323    // Try DRM subsystem product name first (works for all vendors on recent kernels)
324    if let Some(name) = read_drm_product_name(device_dir) {
325        return name;
326    }
327
328    match vendor {
329        "nvidia" => {
330            // Use pre-fetched nvidia-smi data
331            if let Some(name) = nvidia_data.names.get(vendor_index) {
332                if !name.is_empty() {
333                    return name.clone();
334                }
335            }
336            "NVIDIA GPU".to_string()
337        }
338        "amd" => "AMD GPU".to_string(),
339        "intel" => "Intel GPU".to_string(),
340        _ => "Unknown GPU".to_string(),
341    }
342}
343
/// Try to read a GPU name from sysfs attributes below a PCI device directory
///
/// Two sources are consulted, in order, and the first non-blank (trimmed)
/// value wins:
/// 1. `<device_dir>/label`
/// 2. `<device_dir>/drm/<cardN>/device/product_name` for each `drm` entry whose
///    name starts with `card`
///
/// Returns `None` when neither source yields text.
#[cfg(target_os = "linux")]
fn read_drm_product_name(device_dir: &std::path::Path) -> Option<String> {
    // Reads a file and keeps its trimmed content only if something is left.
    let read_non_blank = |path: std::path::PathBuf| -> Option<String> {
        let raw = std::fs::read_to_string(path).ok()?;
        let text = raw.trim();
        if text.is_empty() {
            None
        } else {
            Some(text.to_string())
        }
    };

    read_non_blank(device_dir.join("label")).or_else(|| {
        std::fs::read_dir(device_dir.join("drm"))
            .ok()?
            .flatten()
            .filter(|entry| entry.file_name().to_string_lossy().starts_with("card"))
            .find_map(|entry| read_non_blank(entry.path().join("device").join("product_name")))
    })
}
378
379// =============================================================================
380// VRAM detection (Linux only)
381// =============================================================================
382
383/// Read GPU VRAM from sysfs PCI BAR regions or nvidia-smi
384#[cfg(target_os = "linux")]
385fn read_gpu_memory(
386    device_dir: &std::path::Path,
387    vendor: &str,
388    nvidia_data: &NvidiaSmiData,
389    vendor_index: usize,
390) -> u64 {
391    // For NVIDIA, prefer nvidia-smi data (more accurate than PCI BAR)
392    if vendor == "nvidia" {
393        if let Some(&mem) = nvidia_data.memories.get(vendor_index) {
394            if mem > 0 {
395                return mem;
396            }
397        }
398    }
399
400    // For AMD, try the VRAM-specific sysfs file
401    if vendor == "amd" {
402        let vram_path = device_dir.join("mem_info_vram_total");
403        if let Ok(content) = std::fs::read_to_string(&vram_path) {
404            if let Ok(bytes) = content.trim().parse::<u64>() {
405                return bytes / (1024 * 1024);
406            }
407        }
408    }
409
410    // Fall back to reading PCI resource file for BAR sizes
411    // The largest BAR region is typically VRAM
412    let resource_path = device_dir.join("resource");
413    if let Ok(content) = std::fs::read_to_string(&resource_path) {
414        let mut max_size: u64 = 0;
415        for line in content.lines() {
416            let parts: Vec<&str> = line.split_whitespace().collect();
417            if parts.len() >= 2 {
418                if let (Ok(start), Ok(end)) = (
419                    u64::from_str_radix(parts[0].trim_start_matches("0x"), 16),
420                    u64::from_str_radix(parts[1].trim_start_matches("0x"), 16),
421                ) {
422                    if end > start {
423                        let size = end - start + 1;
424                        if size > max_size {
425                            max_size = size;
426                        }
427                    }
428                }
429            }
430        }
431        if max_size > 0 {
432            return max_size / (1024 * 1024);
433        }
434    }
435
436    0
437}
438
439// =============================================================================
440// Device path resolution (Linux only)
441// =============================================================================
442
/// Derive the device node paths for a GPU from its vendor and per-vendor index
///
/// - `"nvidia"`: `/dev/nvidia<index>` and no render node.
/// - anything else (AMD, Intel, unknown): `/dev/dri/card<index>` together with
///   `/dev/dri/renderD<128 + index>`.
///
/// The paths are computed, not probed: the PCI bus ID is accepted but unused,
/// and nothing checks that the nodes exist.
#[cfg(target_os = "linux")]
fn find_device_paths(
    _pci_bus_id: &str,
    vendor: &str,
    vendor_index: usize,
) -> (String, Option<String>) {
    match vendor {
        "nvidia" => (format!("/dev/nvidia{vendor_index}"), None),
        _ => {
            // The render node number is offset by 128 from the card index.
            let render_minor = 128 + vendor_index;
            (
                format!("/dev/dri/card{vendor_index}"),
                Some(format!("/dev/dri/renderD{render_minor}")),
            )
        }
    }
}
460
461// =============================================================================
462// Tests
463// =============================================================================
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize `info` to JSON, parse it back, and require an identical value.
    fn assert_json_roundtrip(info: &GpuInfo) {
        let encoded = serde_json::to_string(info).unwrap();
        let decoded: GpuInfo = serde_json::from_str(&encoded).unwrap();
        assert_eq!(*info, decoded);
    }

    /// Shorthand for the `Some(String)` render paths the Linux tests expect.
    #[cfg(target_os = "linux")]
    fn render_node(path: &str) -> Option<String> {
        Some(path.to_string())
    }

    #[test]
    fn test_gpu_info_serialization_roundtrip() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "0000:01:00.0".to_string(),
            vendor: "nvidia".to_string(),
            model: "NVIDIA A100-SXM4-80GB".to_string(),
            memory_mb: 81920,
            device_path: "/dev/nvidia0".to_string(),
            render_path: None,
        });
    }

    #[test]
    fn test_gpu_info_amd_serialization() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "0000:03:00.0".to_string(),
            vendor: "amd".to_string(),
            model: "AMD GPU".to_string(),
            memory_mb: 16384,
            device_path: "/dev/dri/card0".to_string(),
            render_path: Some("/dev/dri/renderD128".to_string()),
        });
    }

    #[test]
    fn test_gpu_info_apple_serialization() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "apple:0".to_string(),
            vendor: "apple".to_string(),
            model: "Apple M2 Pro".to_string(),
            memory_mb: 32768,
            device_path: "iokit://AppleGPU/0".to_string(),
            render_path: None,
        });
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_nvidia() {
        // NVIDIA nodes are numbered by index and never carry a render path.
        for (bus_id, index, expected) in [
            ("0000:01:00.0", 0, "/dev/nvidia0"),
            ("0000:02:00.0", 1, "/dev/nvidia1"),
        ] {
            let (dev, render) = find_device_paths(bus_id, "nvidia", index);
            assert_eq!(dev, expected);
            assert!(render.is_none());
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_amd() {
        let (dev, render) = find_device_paths("0000:03:00.0", "amd", 0);
        assert_eq!(dev, "/dev/dri/card0");
        assert_eq!(render, render_node("/dev/dri/renderD128"));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_intel() {
        let (dev, render) = find_device_paths("0000:00:02.0", "intel", 0);
        assert_eq!(dev, "/dev/dri/card0");
        assert_eq!(render, render_node("/dev/dri/renderD128"));
    }

    #[test]
    fn test_detect_gpus_returns_vec() {
        // Hosts without GPUs (typical CI) yield an empty list and the loop is a
        // no-op; on hosts with GPUs every entry must have its identifying
        // fields populated.
        for gpu in detect_gpus() {
            assert!(!gpu.pci_bus_id.is_empty());
            assert!(!gpu.vendor.is_empty());
            assert!(!gpu.model.is_empty());
            assert!(!gpu.device_path.is_empty());
        }
    }
}