1use serde::{Deserialize, Serialize};
22
/// Describes one GPU found on the host by `detect_gpus`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GpuInfo {
    /// Identifier of the device. On Linux this is the device's directory name under
    /// `/sys/bus/pci/devices`; on macOS it is a synthetic `apple:{index}`; on Windows it is
    /// either the bus id reported by NVML (or `nvml:{index}` when that is unavailable) or an
    /// id synthesised from the adapter's PNP device id.
    pub pci_bus_id: String,
    /// Lower-case vendor tag assigned by the detection code: `nvidia`, `amd`, `intel`,
    /// `apple` or `unknown`.
    pub vendor: String,
    /// Model name; a generic placeholder such as `NVIDIA GPU` when no name could be read.
    pub model: String,
    /// Memory size in MiB as reported by the platform backend; `0` when it could not be
    /// determined.
    pub memory_mb: u64,
    /// Platform-specific device handle, e.g. `/dev/nvidia0`, `/dev/dri/card0`,
    /// `iokit://AppleGPU/0` or `nvml://0`.
    pub device_path: String,
    /// DRM render node such as `/dev/dri/renderD128`. Only the Linux backend sets this, and
    /// only for non-NVIDIA devices; it is `None` everywhere else.
    pub render_path: Option<String>,
}
39
40#[cfg(target_os = "linux")]
53#[must_use]
54pub fn detect_gpus() -> Vec<GpuInfo> {
55 use std::path::Path;
56
57 let mut gpus = Vec::new();
58
59 let pci_dir = Path::new("/sys/bus/pci/devices");
60 if !pci_dir.exists() {
61 return gpus;
62 }
63
64 let Ok(entries) = std::fs::read_dir(pci_dir) else {
65 return gpus;
66 };
67
68 let nvidia_data = NvidiaSmiData::fetch();
70
71 for entry in entries.flatten() {
72 let device_dir = entry.path();
73
74 let class_path = device_dir.join("class");
76 let class = match std::fs::read_to_string(&class_path) {
77 Ok(c) => c.trim().to_string(),
78 Err(_) => continue,
79 };
80
81 if !class.starts_with("0x0302") && !class.starts_with("0x0300") {
83 continue;
84 }
85
86 let vendor_path = device_dir.join("vendor");
88 let vendor_id = std::fs::read_to_string(&vendor_path)
89 .unwrap_or_default()
90 .trim()
91 .to_string();
92
93 let vendor = match vendor_id.as_str() {
94 "0x10de" => "nvidia",
95 "0x1002" => "amd",
96 "0x8086" => "intel",
97 _ => "unknown",
98 }
99 .to_string();
100
101 let pci_bus_id = entry.file_name().to_string_lossy().to_string();
102
103 let vendor_index = gpus
105 .iter()
106 .filter(|g: &&GpuInfo| g.vendor == vendor)
107 .count();
108
109 let model = read_gpu_model(&device_dir, &vendor, &nvidia_data, vendor_index);
110 let memory_mb = read_gpu_memory(&device_dir, &vendor, &nvidia_data, vendor_index);
111 let (device_path, render_path) = find_device_paths(&pci_bus_id, &vendor, vendor_index);
112
113 gpus.push(GpuInfo {
114 pci_bus_id,
115 vendor,
116 model,
117 memory_mb,
118 device_path,
119 render_path,
120 });
121 }
122
123 gpus
124}
125
/// Enumerates the GPUs of a macOS host.
///
/// Thin public entry point that forwards to `detect_apple_gpus`.
#[cfg(target_os = "macos")]
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    detect_apple_gpus()
}
141
/// Builds the GPU list from the output of `system_profiler SPDisplaysDataType -json`.
///
/// One [`GpuInfo`] is produced per element of the `SPDisplaysDataType` array:
///
/// * `model` comes from `sppci_model`, then `_name`, then the literal `Apple GPU`; a
///   non-empty `sppci_chiptype` that is not already part of the model is appended in
///   parentheses.
/// * `memory_mb` comes from `sppci_vram` when it has the form `<integer> GB` or
///   `<integer> MB`; otherwise the value of `detect_unified_memory_mb` is used.
/// * `vendor` is classified from `sppci_vendor`, the chip type and the model. Anything
///   that is not recognised as Intel or AMD/ATI is reported as `apple`.
///
/// Returns an empty vector when the command fails, exits unsuccessfully, or prints
/// something that is not the expected JSON.
#[cfg(target_os = "macos")]
fn detect_apple_gpus() -> Vec<GpuInfo> {
    let output = match std::process::Command::new("system_profiler")
        .args(["SPDisplaysDataType", "-json"])
        .output()
    {
        Ok(out) if out.status.success() => out,
        _ => return Vec::new(),
    };

    let json_str = String::from_utf8_lossy(&output.stdout);
    let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_str) else {
        return Vec::new();
    };

    // Fallback size for adapters that do not report a parsable `sppci_vram`.
    let unified_memory_mb = detect_unified_memory_mb();

    let Some(displays) = parsed.get("SPDisplaysDataType").and_then(|v| v.as_array()) else {
        return Vec::new();
    };

    let mut gpus = Vec::with_capacity(displays.len());

    for (idx, display) in displays.iter().enumerate() {
        let base_model = display
            .get("sppci_model")
            .and_then(|v| v.as_str())
            .or_else(|| display.get("_name").and_then(|v| v.as_str()))
            .unwrap_or("Apple GPU");

        let chip_type = display
            .get("sppci_chiptype")
            .and_then(|v| v.as_str())
            .unwrap_or("");

        let model = if !chip_type.is_empty() && !base_model.contains(chip_type) {
            format!("{base_model} ({chip_type})")
        } else {
            base_model.to_string()
        };

        let memory_mb = display
            .get("sppci_vram")
            .and_then(|v| v.as_str())
            .and_then(|text| {
                let mut fields = text.split_whitespace();
                let amount: u64 = fields.next()?.parse().ok()?;
                let unit = fields.next()?;
                if unit.eq_ignore_ascii_case("GB") {
                    // An absurdly large value falls back to the unified-memory figure
                    // instead of overflowing.
                    amount.checked_mul(1024)
                } else if unit.eq_ignore_ascii_case("MB") {
                    Some(amount)
                } else {
                    None
                }
            })
            .unwrap_or(unified_memory_mb);

        let vendor_lc = display
            .get("sppci_vendor")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_lowercase();

        let vendor = if vendor_lc.contains("apple")
            || chip_type.to_lowercase().starts_with("apple")
            || model.to_lowercase().contains("apple m")
        {
            "apple"
        } else if vendor_lc.contains("intel") {
            // Intel must be tested before AMD/ATI: the bare substring "ati" also occurs in
            // unrelated words (e.g. "corporation"), which would tag an Intel adapter as AMD.
            "intel"
        } else if vendor_lc.contains("amd") || vendor_lc.contains("ati") {
            "amd"
        } else {
            // Unrecognised vendors are reported as Apple.
            "apple"
        }
        .to_string();

        gpus.push(GpuInfo {
            pci_bus_id: format!("apple:{idx}"),
            vendor,
            model,
            memory_mb,
            device_path: format!("iokit://AppleGPU/{idx}"),
            render_path: None,
        });
    }

    gpus
}
244
/// Returns the value of `sysctl -n hw.memsize`, converted from bytes to MiB.
///
/// Yields `0` when the command cannot be run, exits unsuccessfully, or prints something
/// that does not parse as an unsigned integer.
#[cfg(target_os = "macos")]
fn detect_unified_memory_mb() -> u64 {
    const BYTES_PER_MIB: u64 = 1024 * 1024;

    let Ok(result) = std::process::Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
    else {
        return 0;
    };
    if !result.status.success() {
        return 0;
    }

    String::from_utf8_lossy(&result.stdout)
        .trim()
        .parse::<u64>()
        .map(|bytes| bytes / BYTES_PER_MIB)
        .unwrap_or(0)
}
261
/// Enumerates the GPUs of a Windows host.
///
/// Thin public entry point that forwards to `windows_impl::detect_gpus_windows`.
#[cfg(target_os = "windows")]
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    windows_impl::detect_gpus_windows()
}
286
/// Fallback for targets other than Linux, macOS and Windows: no detection backend exists,
/// so the result is always an empty list.
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    Vec::new()
}
297
/// Per-GPU values collected from `nvidia-smi`, one entry per output line.
///
/// Both vectors are empty when `nvidia-smi` could not be run or exited unsuccessfully.
#[cfg(target_os = "linux")]
struct NvidiaSmiData {
    /// Trimmed lines of the `name` query, in output order.
    names: Vec<String>,
    /// Lines of the `memory.total` query parsed as integers (`0` where a line did not
    /// parse). The values are used directly as `memory_mb` by `read_gpu_memory`.
    memories: Vec<u64>,
}
310
311#[cfg(target_os = "linux")]
312impl NvidiaSmiData {
313 fn fetch() -> Self {
315 let names = Self::query("name");
316 let memories = Self::query("memory.total")
317 .iter()
318 .map(|s| s.trim().parse::<u64>().unwrap_or(0))
319 .collect();
320
321 Self { names, memories }
322 }
323
324 fn query(field: &str) -> Vec<String> {
325 let output = std::process::Command::new("nvidia-smi")
326 .args([
327 &format!("--query-gpu={field}"),
328 "--format=csv,noheader,nounits",
329 ])
330 .output();
331
332 match output {
333 Ok(out) if out.status.success() => {
334 let text = String::from_utf8_lossy(&out.stdout);
335 text.lines().map(|l| l.trim().to_string()).collect()
336 }
337 _ => Vec::new(),
338 }
339 }
340}
341
342#[cfg(target_os = "linux")]
348fn read_gpu_model(
349 device_dir: &std::path::Path,
350 vendor: &str,
351 nvidia_data: &NvidiaSmiData,
352 vendor_index: usize,
353) -> String {
354 if let Some(name) = read_drm_product_name(device_dir) {
356 return name;
357 }
358
359 match vendor {
360 "nvidia" => {
361 if let Some(name) = nvidia_data.names.get(vendor_index) {
363 if !name.is_empty() {
364 return name.clone();
365 }
366 }
367 "NVIDIA GPU".to_string()
368 }
369 "amd" => "AMD GPU".to_string(),
370 "intel" => "Intel GPU".to_string(),
371 _ => "Unknown GPU".to_string(),
372 }
373}
374
/// Looks for a product name in the sysfs directory of a PCI device.
///
/// Tries `<device_dir>/label` first, then `<device_dir>/drm/card*/device/product_name` for
/// each `card*` entry in directory order. The first file whose trimmed contents are
/// non-empty wins; `None` is returned when no such file exists.
#[cfg(target_os = "linux")]
fn read_drm_product_name(device_dir: &std::path::Path) -> Option<String> {
    /// Trimmed contents of `path`, or `None` if unreadable or blank.
    fn non_blank_contents(path: &std::path::Path) -> Option<String> {
        let raw = std::fs::read_to_string(path).ok()?;
        let trimmed = raw.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed.to_string())
        }
    }

    if let Some(label) = non_blank_contents(&device_dir.join("label")) {
        return Some(label);
    }

    std::fs::read_dir(device_dir.join("drm"))
        .ok()?
        .flatten()
        .filter(|node| node.file_name().to_string_lossy().starts_with("card"))
        .find_map(|card| non_blank_contents(&card.path().join("device").join("product_name")))
}
409
410#[cfg(target_os = "linux")]
416fn read_gpu_memory(
417 device_dir: &std::path::Path,
418 vendor: &str,
419 nvidia_data: &NvidiaSmiData,
420 vendor_index: usize,
421) -> u64 {
422 if vendor == "nvidia" {
424 if let Some(&mem) = nvidia_data.memories.get(vendor_index) {
425 if mem > 0 {
426 return mem;
427 }
428 }
429 }
430
431 if vendor == "amd" {
433 let vram_path = device_dir.join("mem_info_vram_total");
434 if let Ok(content) = std::fs::read_to_string(&vram_path) {
435 if let Ok(bytes) = content.trim().parse::<u64>() {
436 return bytes / (1024 * 1024);
437 }
438 }
439 }
440
441 let resource_path = device_dir.join("resource");
444 if let Ok(content) = std::fs::read_to_string(&resource_path) {
445 let mut max_size: u64 = 0;
446 for line in content.lines() {
447 let parts: Vec<&str> = line.split_whitespace().collect();
448 if parts.len() >= 2 {
449 if let (Ok(start), Ok(end)) = (
450 u64::from_str_radix(parts[0].trim_start_matches("0x"), 16),
451 u64::from_str_radix(parts[1].trim_start_matches("0x"), 16),
452 ) {
453 if end > start {
454 let size = end - start + 1;
455 if size > max_size {
456 max_size = size;
457 }
458 }
459 }
460 }
461 }
462 if max_size > 0 {
463 return max_size / (1024 * 1024);
464 }
465 }
466
467 0
468}
469
/// Derives the device node paths for a GPU from its vendor and per-vendor index.
///
/// * `nvidia`: `/dev/nvidia{vendor_index}` and no render node.
/// * any other vendor: `/dev/dri/card{vendor_index}` plus
///   `/dev/dri/renderD{128 + vendor_index}`.
///
/// The paths are computed, not probed, and `_pci_bus_id` is currently unused.
#[cfg(target_os = "linux")]
fn find_device_paths(
    _pci_bus_id: &str,
    vendor: &str,
    vendor_index: usize,
) -> (String, Option<String>) {
    match vendor {
        "nvidia" => (format!("/dev/nvidia{vendor_index}"), None),
        _ => {
            let render_minor = 128 + vendor_index;
            (
                format!("/dev/dri/card{vendor_index}"),
                Some(format!("/dev/dri/renderD{render_minor}")),
            )
        }
    }
}
491
492#[cfg(target_os = "windows")]
497mod windows_impl {
498 use super::GpuInfo;
499 use std::collections::HashMap;
500 use wmi::{Variant, WMIConnection};
501
502 const DISPLAY_CLASS_GUID: &str = "{4d36e968-e325-11ce-bfc1-08002be10318}";
508
509 pub fn detect_gpus_windows() -> Vec<GpuInfo> {
511 let mut gpus: Vec<GpuInfo> = Vec::new();
512
513 let nvml_gpus = detect_via_nvml();
515 gpus.extend(nvml_gpus);
516
517 let wmi_gpus = match detect_via_wmi() {
519 Ok(v) => v,
520 Err(e) => {
521 tracing::warn!(
522 "WMI Win32_VideoController query failed: {e}; \
523 Windows GPU detection falling back to NVML-only"
524 );
525 Vec::new()
526 }
527 };
528
529 let amd_registry = collect_amd_registry_vram().unwrap_or_default();
531
532 for mut wmi_gpu in wmi_gpus {
533 if wmi_gpu.vendor == "amd" {
535 if let Some(key) = pci_key_from_bus_id(&wmi_gpu.pci_bus_id) {
536 if let Some(&vram_bytes) = amd_registry.get(&key) {
537 wmi_gpu.memory_mb = vram_bytes / (1024 * 1024);
538 }
539 }
540 }
541
542 if let Some(existing) = gpus.iter_mut().find(|g| g.pci_bus_id == wmi_gpu.pci_bus_id) {
544 if existing.model.trim().is_empty() || existing.model == "NVIDIA GPU" {
547 wmi_gpu.model.clone_into(&mut existing.model);
548 }
549 continue;
550 }
551
552 gpus.push(wmi_gpu);
553 }
554
555 gpus
556 }
557
558 fn detect_via_nvml() -> Vec<GpuInfo> {
563 let nvml = match nvml_wrapper::Nvml::init() {
567 Ok(n) => n,
568 Err(e) => {
569 tracing::debug!("NVML unavailable (no NVIDIA driver?): {e}");
570 return Vec::new();
571 }
572 };
573
574 let count = match nvml.device_count() {
575 Ok(c) => c,
576 Err(e) => {
577 tracing::debug!("nvmlDeviceGetCount failed: {e}");
578 return Vec::new();
579 }
580 };
581
582 let mut out = Vec::with_capacity(count as usize);
583 for i in 0..count {
584 let device = match nvml.device_by_index(i) {
585 Ok(d) => d,
586 Err(e) => {
587 tracing::debug!("nvmlDeviceGetHandleByIndex({i}) failed: {e}");
588 continue;
589 }
590 };
591
592 let model = device.name().unwrap_or_else(|_| "NVIDIA GPU".to_string());
593
594 let memory_mb = device.memory_info().map_or(0, |m| m.total / (1024 * 1024));
596
597 let pci_bus_id = match device.pci_info() {
598 Ok(info) => canonicalize_pci_bus_id(&info.bus_id),
599 Err(e) => {
600 tracing::debug!("nvmlDeviceGetPciInfo({i}) failed: {e}");
601 format!("nvml:{i}")
604 }
605 };
606
607 out.push(GpuInfo {
608 pci_bus_id,
609 vendor: "nvidia".to_string(),
610 model,
611 memory_mb,
612 device_path: format!("nvml://{i}"),
615 render_path: None,
616 });
617 }
618
619 out
620 }
621
622 fn canonicalize_pci_bus_id(raw: &str) -> String {
627 let mut parts = raw.splitn(2, ':');
628 let (Some(domain), Some(rest)) = (parts.next(), parts.next()) else {
629 return raw.to_ascii_lowercase();
630 };
631 let trimmed_domain = domain.trim_start_matches('0').to_string();
632 let domain_out = if trimmed_domain.is_empty() {
633 "0000".to_string()
634 } else if trimmed_domain.len() < 4 {
635 format!("{trimmed_domain:0>4}")
636 } else {
637 trimmed_domain
638 };
639 format!("{domain_out}:{rest}").to_ascii_lowercase()
640 }
641
642 fn detect_via_wmi() -> Result<Vec<GpuInfo>, String> {
647 let wmi = WMIConnection::new().map_err(|e| format!("WMIConnection::new: {e}"))?;
652
653 let query = "SELECT Name, PNPDeviceID, AdapterRAM \
657 FROM Win32_VideoController \
658 WHERE PNPDeviceID LIKE 'PCI\\\\VEN_%'";
659
660 let rows: Vec<HashMap<String, Variant>> = wmi
661 .raw_query(query)
662 .map_err(|e| format!("raw_query({query}): {e}"))?;
663
664 let mut out = Vec::with_capacity(rows.len());
665 for row in rows {
666 let Some(pnp) = variant_string(row.get("PNPDeviceID")) else {
667 continue;
668 };
669 let Some((vendor_id, device_id)) = parse_ven_dev(&pnp) else {
670 continue;
671 };
672
673 let vendor = vendor_from_pci_id(vendor_id);
674 let model = variant_string(row.get("Name"))
675 .unwrap_or_else(|| format!("{} GPU", vendor.to_ascii_uppercase()));
676
677 let memory_mb = variant_u64(row.get("AdapterRAM")).map_or(0, |b| b / (1024 * 1024));
681
682 let pci_bus_id = pci_bus_id_from_pnp(&pnp, vendor_id, device_id);
683
684 out.push(GpuInfo {
685 pci_bus_id,
686 vendor: vendor.to_string(),
687 model,
688 memory_mb,
689 device_path: format!(r"\\.\DISPLAY#{pnp}"),
690 render_path: None,
691 });
692 }
693
694 Ok(out)
695 }
696
697 fn parse_ven_dev(pnp: &str) -> Option<(u16, u16)> {
700 let upper = pnp.to_ascii_uppercase();
703 let ven = extract_hex(&upper, "VEN_", 4)?;
704 let dev = extract_hex(&upper, "DEV_", 4)?;
705 Some((ven, dev))
706 }
707
708 fn extract_hex(s: &str, marker: &str, nibbles: usize) -> Option<u16> {
709 let start = s.find(marker)? + marker.len();
710 let hex = s.get(start..start + nibbles)?;
711 u16::from_str_radix(hex, 16).ok()
712 }
713
714 fn vendor_from_pci_id(vendor_id: u16) -> &'static str {
715 match vendor_id {
716 0x10DE => "nvidia",
717 0x1002 | 0x1022 => "amd",
719 0x8086 => "intel",
720 _ => "unknown",
721 }
722 }
723
724 fn pci_bus_id_from_pnp(pnp: &str, vendor_id: u16, device_id: u16) -> String {
735 let slot = pnp
739 .rsplit_once('&')
740 .and_then(|(_, tail)| tail.chars().take(4).collect::<String>().parse::<u16>().ok())
741 .unwrap_or(0);
742 format!("0000:{vendor_id:04x}:{device_id:04x}.{slot:x}")
743 }
744
745 fn variant_string(v: Option<&Variant>) -> Option<String> {
746 match v? {
747 Variant::String(s) => Some(s.clone()),
748 _ => None,
749 }
750 }
751
752 fn variant_u64(v: Option<&Variant>) -> Option<u64> {
753 match v? {
754 Variant::UI1(n) => Some(u64::from(*n)),
755 Variant::UI2(n) => Some(u64::from(*n)),
756 Variant::UI4(n) => Some(u64::from(*n)),
757 Variant::UI8(n) => Some(*n),
758 Variant::I1(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
759 Variant::I2(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
760 Variant::I4(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
761 Variant::I8(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
762 _ => None,
763 }
764 }
765
766 fn collect_amd_registry_vram() -> Result<HashMap<(u16, u16), u64>, String> {
774 let class_path = format!(r"SYSTEM\CurrentControlSet\Control\Class\{DISPLAY_CLASS_GUID}");
775 let class_key = windows_registry::LOCAL_MACHINE
776 .open(&class_path)
777 .map_err(|e| format!("open HKLM\\{class_path}: {e}"))?;
778
779 let mut out: HashMap<(u16, u16), u64> = HashMap::new();
780
781 let subkeys = match class_key.keys() {
785 Ok(it) => it,
786 Err(e) => return Err(format!("enumerate class subkeys: {e}")),
787 };
788
789 for name in subkeys {
790 if name.len() != 4 || !name.chars().all(|c| c.is_ascii_digit()) {
793 continue;
794 }
795
796 let Ok(adapter_key) = class_key.open(&name) else {
797 continue;
798 };
799
800 let vendor_id = adapter_key
801 .get_string("MatchingDeviceId")
802 .ok()
803 .as_deref()
804 .and_then(parse_matching_device_id);
805 let Some((ven, dev)) = vendor_id else {
806 continue;
807 };
808
809 if vendor_from_pci_id(ven) != "amd" {
810 continue;
811 }
812
813 if let Ok(bytes) = adapter_key
818 .open("HardwareInformation")
819 .and_then(|hw| hw.get_u64("qwMemorySize"))
820 {
821 out.insert((ven, dev), bytes);
822 }
823 }
824
825 Ok(out)
826 }
827
    /// Extracts the `(vendor_id, device_id)` pair from a `MatchingDeviceId` registry string.
    ///
    /// Delegates to `parse_ven_dev`, so the string must contain `VEN_xxxx` and `DEV_xxxx`
    /// fields (matched case-insensitively); otherwise `None` is returned.
    fn parse_matching_device_id(s: &str) -> Option<(u16, u16)> {
        parse_ven_dev(s)
    }
833
834 fn pci_key_from_bus_id(bus_id: &str) -> Option<(u16, u16)> {
836 let mut parts = bus_id.split(':');
837 let _domain = parts.next()?;
838 let ven = u16::from_str_radix(parts.next()?, 16).ok()?;
839 let dev_fn = parts.next()?;
840 let dev = dev_fn
841 .split('.')
842 .next()
843 .and_then(|h| u16::from_str_radix(h, 16).ok())?;
844 Some((ven, dev))
845 }
846
847 #[cfg(test)]
852 mod tests {
853 use super::*;
854
855 #[test]
856 fn parse_ven_dev_nvidia() {
857 let pnp = r"PCI\VEN_10DE&DEV_2204&SUBSYS_38811462&REV_A1\4&31DE5EF7&0&0008";
858 assert_eq!(parse_ven_dev(pnp), Some((0x10DE, 0x2204)));
859 }
860
861 #[test]
862 fn parse_ven_dev_amd() {
863 let pnp = r"PCI\VEN_1002&DEV_73A5&SUBSYS_E4571DA2&REV_C0\4&1A2B3C4D&0&0010";
864 assert_eq!(parse_ven_dev(pnp), Some((0x1002, 0x73A5)));
865 }
866
867 #[test]
868 fn parse_ven_dev_intel_lowercase() {
869 let pnp = r"pci\ven_8086&dev_9a49&subsys_00000000&rev_01\3&11583659&0&10";
870 assert_eq!(parse_ven_dev(pnp), Some((0x8086, 0x9A49)));
871 }
872
873 #[test]
874 fn parse_ven_dev_rejects_malformed() {
875 assert_eq!(parse_ven_dev("USB\\VID_1234&PID_5678"), None);
876 assert_eq!(parse_ven_dev(""), None);
877 }
878
879 #[test]
880 fn vendor_id_mapping() {
881 assert_eq!(vendor_from_pci_id(0x10DE), "nvidia");
882 assert_eq!(vendor_from_pci_id(0x1002), "amd");
883 assert_eq!(vendor_from_pci_id(0x1022), "amd");
884 assert_eq!(vendor_from_pci_id(0x8086), "intel");
885 assert_eq!(vendor_from_pci_id(0x1234), "unknown");
886 }
887
888 #[test]
889 fn canonicalize_strips_nvml_domain_padding() {
890 assert_eq!(canonicalize_pci_bus_id("00000000:01:00.0"), "0000:01:00.0");
891 assert_eq!(canonicalize_pci_bus_id("0000:17:00.0"), "0000:17:00.0");
892 assert_eq!(canonicalize_pci_bus_id("000a:17:00.0"), "000a:17:00.0");
893 }
894
895 #[test]
896 fn canonicalize_handles_missing_colon() {
897 assert_eq!(canonicalize_pci_bus_id("WEIRD"), "weird");
899 }
900
901 #[test]
902 fn pci_key_from_bus_id_roundtrip() {
903 let bus = pci_bus_id_from_pnp(
904 r"PCI\VEN_1002&DEV_73A5&SUBSYS_E4571DA2&REV_C0\4&1A2B3C4D&0&0010",
905 0x1002,
906 0x73A5,
907 );
908 assert_eq!(pci_key_from_bus_id(&bus), Some((0x1002, 0x73A5)));
909 }
910
911 #[test]
912 fn variant_u64_accepts_unsigned_widths() {
913 assert_eq!(variant_u64(Some(&Variant::UI4(4096))), Some(4096));
914 assert_eq!(
915 variant_u64(Some(&Variant::UI8(17_179_869_184))),
916 Some(17_179_869_184)
917 );
918 assert_eq!(variant_u64(Some(&Variant::UI1(7))), Some(7));
919 }
920
921 #[test]
922 fn variant_u64_rejects_negative_signed() {
923 assert_eq!(variant_u64(Some(&Variant::I4(-1))), None);
924 }
925
926 #[test]
927 fn variant_string_unwraps() {
928 assert_eq!(
929 variant_string(Some(&Variant::String("NVIDIA GeForce RTX 4090".into()))),
930 Some("NVIDIA GeForce RTX 4090".to_string())
931 );
932 assert_eq!(variant_string(Some(&Variant::UI4(7))), None);
933 assert_eq!(variant_string(None), None);
934 }
935 }
936}
937
/// Platform-independent tests for `GpuInfo`, plus tests for the Linux path helper.
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises `info` to JSON, reads it back, and checks nothing was lost.
    fn assert_json_roundtrip(info: &GpuInfo) {
        let json = serde_json::to_string(info).unwrap();
        let decoded: GpuInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(*info, decoded);
    }

    #[test]
    fn test_gpu_info_serialization_roundtrip() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "0000:01:00.0".to_string(),
            vendor: "nvidia".to_string(),
            model: "NVIDIA A100-SXM4-80GB".to_string(),
            memory_mb: 81920,
            device_path: "/dev/nvidia0".to_string(),
            render_path: None,
        });
    }

    #[test]
    fn test_gpu_info_amd_serialization() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "0000:03:00.0".to_string(),
            vendor: "amd".to_string(),
            model: "AMD GPU".to_string(),
            memory_mb: 16384,
            device_path: "/dev/dri/card0".to_string(),
            render_path: Some("/dev/dri/renderD128".to_string()),
        });
    }

    #[test]
    fn test_gpu_info_apple_serialization() {
        assert_json_roundtrip(&GpuInfo {
            pci_bus_id: "apple:0".to_string(),
            vendor: "apple".to_string(),
            model: "Apple M2 Pro".to_string(),
            memory_mb: 32768,
            device_path: "iokit://AppleGPU/0".to_string(),
            render_path: None,
        });
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_nvidia() {
        let cases = [
            ("0000:01:00.0", 0, "/dev/nvidia0"),
            ("0000:02:00.0", 1, "/dev/nvidia1"),
        ];
        for (bus_id, index, expected_dev) in cases {
            let (dev, render) = find_device_paths(bus_id, "nvidia", index);
            assert_eq!(dev, expected_dev);
            assert!(render.is_none());
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_amd() {
        let (dev, render) = find_device_paths("0000:03:00.0", "amd", 0);
        assert_eq!(dev, "/dev/dri/card0");
        assert_eq!(render.as_deref(), Some("/dev/dri/renderD128"));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_intel() {
        let (dev, render) = find_device_paths("0000:00:02.0", "intel", 0);
        assert_eq!(dev, "/dev/dri/card0");
        assert_eq!(render.as_deref(), Some("/dev/dri/renderD128"));
    }

    #[test]
    fn test_detect_gpus_returns_vec() {
        // The host may have no GPU at all; only the shape of whatever is found is checked.
        for gpu in detect_gpus() {
            assert!(!gpu.pci_bus_id.is_empty());
            assert!(!gpu.vendor.is_empty());
            assert!(!gpu.model.is_empty());
            assert!(!gpu.device_path.is_empty());
        }
    }
}
1034}