1use serde::{Deserialize, Serialize};
22
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
25pub struct GpuInfo {
26 pub pci_bus_id: String,
28 pub vendor: String,
30 pub model: String,
32 pub memory_mb: u64,
34 pub device_path: String,
36 pub render_path: Option<String>,
38}
39
40#[cfg(target_os = "linux")]
53#[must_use]
54pub fn detect_gpus() -> Vec<GpuInfo> {
55 use std::path::Path;
56
57 let mut gpus = Vec::new();
58
59 let pci_dir = Path::new("/sys/bus/pci/devices");
60 if !pci_dir.exists() {
61 return gpus;
62 }
63
64 let Ok(entries) = std::fs::read_dir(pci_dir) else {
65 return gpus;
66 };
67
68 let nvidia_data = NvidiaSmiData::fetch();
70
71 for entry in entries.flatten() {
72 let device_dir = entry.path();
73
74 let class_path = device_dir.join("class");
76 let class = match std::fs::read_to_string(&class_path) {
77 Ok(c) => c.trim().to_string(),
78 Err(_) => continue,
79 };
80
81 if !class.starts_with("0x0302") && !class.starts_with("0x0300") {
83 continue;
84 }
85
86 let vendor_path = device_dir.join("vendor");
88 let vendor_id = std::fs::read_to_string(&vendor_path)
89 .unwrap_or_default()
90 .trim()
91 .to_string();
92
93 let vendor = match vendor_id.as_str() {
94 "0x10de" => "nvidia",
95 "0x1002" => "amd",
96 "0x8086" => "intel",
97 _ => "unknown",
98 }
99 .to_string();
100
101 let pci_bus_id = entry.file_name().to_string_lossy().to_string();
102
103 let vendor_index = gpus
105 .iter()
106 .filter(|g: &&GpuInfo| g.vendor == vendor)
107 .count();
108
109 let model = read_gpu_model(&device_dir, &vendor, &nvidia_data, vendor_index);
110 let memory_mb = read_gpu_memory(&device_dir, &vendor, &nvidia_data, vendor_index);
111 let (device_path, render_path) = find_device_paths(&pci_bus_id, &vendor, vendor_index);
112
113 gpus.push(GpuInfo {
114 pci_bus_id,
115 vendor,
116 model,
117 memory_mb,
118 device_path,
119 render_path,
120 });
121 }
122
123 gpus
124}
125
#[cfg(target_os = "macos")]
/// Detect GPUs on macOS.
///
/// Delegates to `detect_apple_gpus`, which parses the JSON output of
/// `system_profiler SPDisplaysDataType -json`. Returns an empty vector if the
/// tool cannot be run or its output cannot be parsed.
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    detect_apple_gpus()
}
141
#[cfg(target_os = "macos")]
/// Enumerate GPUs by parsing `system_profiler SPDisplaysDataType -json`.
///
/// Each element of the `SPDisplaysDataType` array becomes one [`GpuInfo`] with a
/// synthetic `apple:{index}` bus id and an `iokit://AppleGPU/{index}` device path.
/// When no VRAM size can be parsed, the total system memory from
/// `detect_unified_memory_mb` is reported instead.
///
/// Returns an empty vector if `system_profiler` fails or prints anything other
/// than a JSON document containing an `SPDisplaysDataType` array.
fn detect_apple_gpus() -> Vec<GpuInfo> {
    let output = match std::process::Command::new("system_profiler")
        .args(["SPDisplaysDataType", "-json"])
        .output()
    {
        Ok(out) if out.status.success() => out,
        _ => return Vec::new(),
    };

    let json_str = String::from_utf8_lossy(&output.stdout);
    let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_str) else {
        return Vec::new();
    };

    let Some(displays) = parsed
        .get("SPDisplaysDataType")
        .and_then(serde_json::Value::as_array)
    else {
        return Vec::new();
    };

    let unified_memory_mb = detect_unified_memory_mb();

    let mut gpus = Vec::with_capacity(displays.len());
    for (idx, display) in displays.iter().enumerate() {
        let field = |key: &str| display.get(key).and_then(serde_json::Value::as_str);

        let base_model = field("sppci_model")
            .or_else(|| field("_name"))
            .unwrap_or("Apple GPU");
        let chip_type = field("sppci_chiptype").unwrap_or("");
        // Append the chip type unless the model name already mentions it.
        let model = if chip_type.is_empty() || base_model.contains(chip_type) {
            base_model.to_string()
        } else {
            format!("{base_model} ({chip_type})")
        };

        // NOTE(review): the `spdisplays_*` fallbacks are the key spellings macOS
        // appears to use for VRAM and vendor in `-json` output — confirm against
        // a real dump. An absent key is harmless.
        let memory_mb = field("sppci_vram")
            .or_else(|| field("spdisplays_vram"))
            .and_then(parse_vram_mb)
            .unwrap_or(unified_memory_mb);

        let vendor_str = field("sppci_vendor")
            .or_else(|| field("spdisplays_vendor"))
            .unwrap_or("");
        let vendor = classify_apple_gpu_vendor(vendor_str, chip_type, &model).to_string();

        gpus.push(GpuInfo {
            pci_bus_id: format!("apple:{idx}"),
            vendor,
            model,
            memory_mb,
            device_path: format!("iokit://AppleGPU/{idx}"),
            render_path: None,
        });
    }

    gpus
}

#[cfg(any(target_os = "macos", test))]
/// Parse a `system_profiler` VRAM string such as `"8 GB"` or `"1536 MB"` into MiB.
///
/// The first whitespace-separated token must be an integer and the second `GB`
/// or `MB` (any case). Any further tokens are ignored. Returns `None` for
/// anything else, or if the result would overflow.
fn parse_vram_mb(text: &str) -> Option<u64> {
    let mut parts = text.split_whitespace();
    let amount: u64 = parts.next()?.parse().ok()?;
    match parts.next()?.to_ascii_uppercase().as_str() {
        "GB" => amount.checked_mul(1024),
        "MB" => Some(amount),
        _ => None,
    }
}

#[cfg(any(target_os = "macos", test))]
/// Map the vendor/chip/model strings of a macOS GPU to a [`GpuInfo::vendor`] tag.
///
/// Comparisons are case-insensitive. Apple silicon is recognised first. Then
/// AMD/ATI, Intel and NVIDIA are recognised from the vendor string. Anything
/// else is reported as `"apple"`.
fn classify_apple_gpu_vendor(vendor_str: &str, chip_type: &str, model: &str) -> &'static str {
    let vendor = vendor_str.to_lowercase();
    if vendor.contains("apple")
        || chip_type.to_lowercase().starts_with("apple")
        || model.to_lowercase().contains("apple m")
    {
        return "apple";
    }

    // Match "ATI" only as a whole word. A substring test also hits
    // "Corpor-ati-on", which turned "Intel Corporation" and "NVIDIA Corporation" into AMD.
    let is_ati = vendor
        .split(|c: char| !c.is_ascii_alphanumeric())
        .any(|word| word == "ati");

    if vendor.contains("amd") || is_ati {
        "amd"
    } else if vendor.contains("intel") {
        "intel"
    } else if vendor.contains("nvidia") {
        "nvidia"
    } else {
        "apple"
    }
}

#[cfg(test)]
mod apple_gpu_parsing_tests {
    use super::*;

    #[test]
    fn vram_strings_parse_to_mib() {
        assert_eq!(parse_vram_mb("8 GB"), Some(8192));
        assert_eq!(parse_vram_mb("1536 mb"), Some(1536));
        assert_eq!(parse_vram_mb("1.5 GB"), None);
        assert_eq!(parse_vram_mb(""), None);
    }

    #[test]
    fn corporation_suffix_is_not_ati() {
        assert_eq!(classify_apple_gpu_vendor("Intel Corporation", "", "Intel HD"), "intel");
        assert_eq!(classify_apple_gpu_vendor("NVIDIA Corporation", "", "GT 750M"), "nvidia");
        assert_eq!(classify_apple_gpu_vendor("ATI Technologies Inc.", "", "HD 6750M"), "amd");
    }
}
244
#[cfg(target_os = "macos")]
/// Total physical memory in MiB, as printed by `sysctl -n hw.memsize`.
///
/// `detect_apple_gpus` reports this figure for GPUs that expose no VRAM size.
/// Returns 0 if `sysctl` cannot be run, exits unsuccessfully, or prints
/// something other than an unsigned integer.
fn detect_unified_memory_mb() -> u64 {
    const BYTES_PER_MIB: u64 = 1024 * 1024;

    std::process::Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
        .ok()
        .filter(|result| result.status.success())
        .and_then(|result| {
            String::from_utf8_lossy(&result.stdout)
                .trim()
                .parse::<u64>()
                .ok()
        })
        .map_or(0, |total_bytes| total_bytes / BYTES_PER_MIB)
}
262
#[cfg(target_os = "windows")]
/// Detect GPUs on Windows.
///
/// Delegates to `windows_impl::detect_gpus_windows`, which combines NVML results
/// for NVIDIA GPUs with the display adapters reported by WMI.
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    windows_impl::detect_gpus_windows()
}
287
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
/// Fallback for platforms without a GPU detection backend.
///
/// Always returns an empty vector.
#[must_use]
pub fn detect_gpus() -> Vec<GpuInfo> {
    Vec::new()
}
298
#[cfg(target_os = "linux")]
/// GPU names and total memory reported by `nvidia-smi`, one entry per GPU in
/// the order `nvidia-smi` lists them.
///
/// Both vectors are empty when `nvidia-smi` is missing or fails.
struct NvidiaSmiData {
    /// GPU product names, e.g. `NVIDIA A100-SXM4-80GB`.
    names: Vec<String>,
    /// Total memory in MiB, indexed like `names`; 0 when the value did not parse
    /// (e.g. `[N/A]`).
    memories: Vec<u64>,
}

#[cfg(target_os = "linux")]
impl NvidiaSmiData {
    /// Query every GPU's name and total memory from `nvidia-smi`.
    ///
    /// Both columns are requested in a single invocation. Each `nvidia-smi`
    /// launch initialises the driver's management library, which is slow, and a
    /// single query also guarantees the two columns come from the same rows.
    fn fetch() -> Self {
        let output = std::process::Command::new("nvidia-smi")
            .args([
                "--query-gpu=name,memory.total",
                "--format=csv,noheader,nounits",
            ])
            .output();

        match output {
            Ok(out) if out.status.success() => {
                Self::parse_csv(&String::from_utf8_lossy(&out.stdout))
            }
            _ => Self {
                names: Vec::new(),
                memories: Vec::new(),
            },
        }
    }

    /// Parse `name, memory.total` rows as printed with `csv,noheader,nounits`.
    ///
    /// Blank lines are skipped. Each row is split on its last comma: the memory
    /// column is numeric, so any earlier comma belongs to the name. A row without
    /// a comma is kept as a name with memory 0.
    fn parse_csv(text: &str) -> Self {
        let mut names = Vec::new();
        let mut memories = Vec::new();
        for line in text.lines().map(str::trim).filter(|l| !l.is_empty()) {
            let (name, memory) = line.rsplit_once(',').unwrap_or((line, ""));
            names.push(name.trim().to_string());
            memories.push(memory.trim().parse::<u64>().unwrap_or(0));
        }
        Self { names, memories }
    }
}
342
343#[cfg(target_os = "linux")]
349fn read_gpu_model(
350 device_dir: &std::path::Path,
351 vendor: &str,
352 nvidia_data: &NvidiaSmiData,
353 vendor_index: usize,
354) -> String {
355 if let Some(name) = read_drm_product_name(device_dir) {
357 return name;
358 }
359
360 match vendor {
361 "nvidia" => {
362 if let Some(name) = nvidia_data.names.get(vendor_index) {
364 if !name.is_empty() {
365 return name.clone();
366 }
367 }
368 "NVIDIA GPU".to_string()
369 }
370 "amd" => "AMD GPU".to_string(),
371 "intel" => "Intel GPU".to_string(),
372 _ => "Unknown GPU".to_string(),
373 }
374}
375
#[cfg(target_os = "linux")]
/// Look for a human-readable name for the PCI device at `device_dir`.
///
/// Two places are checked, in order:
/// 1. the device's `label` file;
/// 2. `drm/<entry>/device/product_name` for each `drm` entry whose name starts
///    with `card`.
///
/// File contents are trimmed and empty values skipped. Returns `None` when no
/// non-empty value is found.
fn read_drm_product_name(device_dir: &std::path::Path) -> Option<String> {
    // Read a file, returning its trimmed contents only when non-empty.
    let read_non_empty = |path: &std::path::Path| -> Option<String> {
        let text = std::fs::read_to_string(path).ok()?;
        let text = text.trim();
        (!text.is_empty()).then(|| text.to_owned())
    };

    if let Some(label) = read_non_empty(&device_dir.join("label")) {
        return Some(label);
    }

    std::fs::read_dir(device_dir.join("drm"))
        .ok()?
        .flatten()
        .filter(|drm_entry| drm_entry.file_name().to_string_lossy().starts_with("card"))
        .find_map(|card| read_non_empty(&card.path().join("device").join("product_name")))
}
410
411#[cfg(target_os = "linux")]
417fn read_gpu_memory(
418 device_dir: &std::path::Path,
419 vendor: &str,
420 nvidia_data: &NvidiaSmiData,
421 vendor_index: usize,
422) -> u64 {
423 if vendor == "nvidia" {
425 if let Some(&mem) = nvidia_data.memories.get(vendor_index) {
426 if mem > 0 {
427 return mem;
428 }
429 }
430 }
431
432 if vendor == "amd" {
434 let vram_path = device_dir.join("mem_info_vram_total");
435 if let Ok(content) = std::fs::read_to_string(&vram_path) {
436 if let Ok(bytes) = content.trim().parse::<u64>() {
437 return bytes / (1024 * 1024);
438 }
439 }
440 }
441
442 let resource_path = device_dir.join("resource");
445 if let Ok(content) = std::fs::read_to_string(&resource_path) {
446 let mut max_size: u64 = 0;
447 for line in content.lines() {
448 let parts: Vec<&str> = line.split_whitespace().collect();
449 if parts.len() >= 2 {
450 if let (Ok(start), Ok(end)) = (
451 u64::from_str_radix(parts[0].trim_start_matches("0x"), 16),
452 u64::from_str_radix(parts[1].trim_start_matches("0x"), 16),
453 ) {
454 if end > start {
455 let size = end - start + 1;
456 if size > max_size {
457 max_size = size;
458 }
459 }
460 }
461 }
462 }
463 if max_size > 0 {
464 return max_size / (1024 * 1024);
465 }
466 }
467
468 0
469}
470
#[cfg(target_os = "linux")]
/// Guess the device node(s) for a GPU from its vendor and per-vendor index.
///
/// - NVIDIA: `/dev/nvidia{index}` and no render node.
/// - Everything else: `/dev/dri/card{index}` plus `/dev/dri/renderD{128 + index}`.
///
/// NOTE(review): the index counts GPUs of one vendor only. On mixed-vendor
/// machines (e.g. an Intel iGPU plus an AMD dGPU), two GPUs can therefore be
/// given the same DRM node. The PCI device's own `drm` directory would name the
/// real nodes. `_pci_bus_id` is currently unused.
fn find_device_paths(
    _pci_bus_id: &str,
    vendor: &str,
    vendor_index: usize,
) -> (String, Option<String>) {
    match vendor {
        "nvidia" => (format!("/dev/nvidia{vendor_index}"), None),
        _ => {
            let render_minor = 128 + vendor_index;
            (
                format!("/dev/dri/card{vendor_index}"),
                Some(format!("/dev/dri/renderD{render_minor}")),
            )
        }
    }
}
492
#[cfg(target_os = "windows")]
/// Windows GPU detection.
///
/// Results are merged from three sources:
/// - NVML (the NVIDIA driver's management library) for NVIDIA GPUs. It provides
///   real PCI bus ids, names and memory sizes.
/// - WMI `Win32_VideoController` rows whose `PNPDeviceID` is a PCI id.
/// - The display-adapter class keys in the registry. They supply AMD memory sizes,
///   because WMI's `AdapterRAM` (documented as `uint32`) cannot describe more
///   than 4 GiB.
mod windows_impl {
    use super::GpuInfo;
    use std::collections::HashMap;
    use wmi::{Variant, WMIConnection};

    /// Device setup class GUID of display adapters.
    const DISPLAY_CLASS_GUID: &str = "{4d36e968-e325-11ce-bfc1-08002be10318}";

    /// Registry value holding an adapter's memory size in bytes.
    ///
    /// It lives directly under each adapter's class subkey. The dot is part of
    /// the value name; there is no `HardwareInformation` subkey.
    const QW_MEMORY_SIZE_VALUE: &str = "HardwareInformation.qwMemorySize";

    const BYTES_PER_MIB: u64 = 1024 * 1024;

    /// Detect GPUs via NVML and WMI and merge the results.
    ///
    /// NVML entries come first. Each WMI adapter is then appended, with one
    /// exception: a WMI NVIDIA adapter is folded into an unclaimed NVML entry
    /// (see [`claim_nvml_entry`]), so one physical GPU is not listed twice. AMD
    /// adapters take their memory size from the registry when one is recorded.
    pub fn detect_gpus_windows() -> Vec<GpuInfo> {
        let mut gpus = detect_via_nvml();
        let nvml_count = gpus.len();
        let mut nvml_claimed = vec![false; nvml_count];

        let wmi_gpus = match detect_via_wmi() {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!(
                    "WMI Win32_VideoController query failed: {e}; \
                     Windows GPU detection falling back to NVML-only"
                );
                Vec::new()
            }
        };

        let amd_registry = collect_amd_registry_vram().unwrap_or_else(|e| {
            tracing::debug!("AMD registry VRAM lookup failed: {e}");
            HashMap::new()
        });

        for mut wmi_gpu in wmi_gpus {
            if wmi_gpu.vendor == "amd" {
                let registry_bytes = pci_key_from_bus_id(&wmi_gpu.pci_bus_id)
                    .and_then(|key| amd_registry.get(&key).copied());
                if let Some(vram_bytes) = registry_bytes {
                    wmi_gpu.memory_mb = vram_bytes / BYTES_PER_MIB;
                }
            }

            // WMI exposes no PCI bus number, so a WMI id never equals NVML's real
            // bus id. Pair NVIDIA adapters with NVML entries one-to-one instead.
            let nvml_match = if wmi_gpu.vendor == "nvidia" {
                claim_nvml_entry(&gpus[..nvml_count], &mut nvml_claimed, &wmi_gpu.model)
            } else {
                None
            };
            if let Some(idx) = nvml_match {
                let existing = &mut gpus[idx];
                // Keep NVML's name unless it fell back to the generic placeholder.
                if existing.model.trim().is_empty() || existing.model == "NVIDIA GPU" {
                    existing.model = wmi_gpu.model;
                }
                continue;
            }

            gpus.push(wmi_gpu);
        }

        gpus
    }

    /// Claim the NVML entry that a WMI NVIDIA adapter named `wmi_model` describes.
    ///
    /// Only entries whose `claimed` flag is still `false` are considered. Among
    /// them, the first one whose model equals `wmi_model` (trimmed, ASCII
    /// case-insensitive) is preferred; otherwise the first one is used. The chosen
    /// entry is marked as claimed and its index returned.
    ///
    /// Returns `None` once every entry has been claimed, i.e. when WMI reports
    /// more NVIDIA adapters than NVML does.
    ///
    /// `claimed` must have the same length as `nvml_gpus`.
    fn claim_nvml_entry(
        nvml_gpus: &[GpuInfo],
        claimed: &mut [bool],
        wmi_model: &str,
    ) -> Option<usize> {
        debug_assert_eq!(nvml_gpus.len(), claimed.len());
        let wanted = wmi_model.trim();
        let unclaimed: Vec<usize> = claimed
            .iter()
            .enumerate()
            .filter_map(|(i, &taken)| (!taken).then_some(i))
            .collect();
        let idx = unclaimed
            .iter()
            .copied()
            .find(|&i| {
                nvml_gpus
                    .get(i)
                    .is_some_and(|g| g.model.trim().eq_ignore_ascii_case(wanted))
            })
            .or_else(|| unclaimed.first().copied())?;
        claimed[idx] = true;
        Some(idx)
    }

    /// Enumerate NVIDIA GPUs through NVML.
    ///
    /// Returns an empty vector when NVML cannot be initialised (for example
    /// because no NVIDIA driver is installed) or the device count cannot be read.
    /// Devices whose handle cannot be obtained are skipped. Per-device fallbacks:
    /// name `"NVIDIA GPU"`, memory 0, bus id `nvml:{index}`.
    fn detect_via_nvml() -> Vec<GpuInfo> {
        let nvml = match nvml_wrapper::Nvml::init() {
            Ok(n) => n,
            Err(e) => {
                tracing::debug!("NVML unavailable (no NVIDIA driver?): {e}");
                return Vec::new();
            }
        };

        let count = match nvml.device_count() {
            Ok(c) => c,
            Err(e) => {
                tracing::debug!("nvmlDeviceGetCount failed: {e}");
                return Vec::new();
            }
        };

        let mut out = Vec::with_capacity(count as usize);
        for i in 0..count {
            let device = match nvml.device_by_index(i) {
                Ok(d) => d,
                Err(e) => {
                    tracing::debug!("nvmlDeviceGetHandleByIndex({i}) failed: {e}");
                    continue;
                }
            };

            let model = device.name().unwrap_or_else(|_| "NVIDIA GPU".to_string());

            let memory_mb = device
                .memory_info()
                .map(|m| m.total / BYTES_PER_MIB)
                .unwrap_or(0);

            let pci_bus_id = match device.pci_info() {
                Ok(info) => canonicalize_pci_bus_id(&info.bus_id),
                Err(e) => {
                    tracing::debug!("nvmlDeviceGetPciInfo({i}) failed: {e}");
                    format!("nvml:{i}")
                }
            };

            out.push(GpuInfo {
                pci_bus_id,
                vendor: "nvidia".to_string(),
                model,
                memory_mb,
                device_path: format!("nvml://{i}"),
                render_path: None,
            });
        }

        out
    }

    /// Normalise an NVML bus id to the lower-case `DDDD:BB:DD.F` form.
    ///
    /// NVML pads the domain to eight hex digits (`00000000:01:00.0`). Leading
    /// zeros are trimmed and the domain is re-padded to at least four digits;
    /// wider domains are kept as they are. Input without a `:` is only
    /// lower-cased.
    fn canonicalize_pci_bus_id(raw: &str) -> String {
        let Some((domain, rest)) = raw.split_once(':') else {
            return raw.to_ascii_lowercase();
        };
        let trimmed_domain = domain.trim_start_matches('0');
        format!("{trimmed_domain:0>4}:{rest}").to_ascii_lowercase()
    }

    /// Query WMI for PCI display adapters.
    ///
    /// Rows whose `PNPDeviceID` is missing or lacks `VEN_`/`DEV_` ids are
    /// skipped. `AdapterRAM` is converted from bytes to MiB, or 0 when absent.
    ///
    /// # Errors
    /// Returns a description of the failure if the WMI connection or the query
    /// fails.
    fn detect_via_wmi() -> Result<Vec<GpuInfo>, String> {
        let wmi = WMIConnection::new().map_err(|e| format!("WMIConnection::new: {e}"))?;

        // `\\\\` in this Rust literal is `\\` in WQL, i.e. one literal backslash.
        let query = "SELECT Name, PNPDeviceID, AdapterRAM \
                     FROM Win32_VideoController \
                     WHERE PNPDeviceID LIKE 'PCI\\\\VEN_%'";

        let rows: Vec<HashMap<String, Variant>> = wmi
            .raw_query(query)
            .map_err(|e| format!("raw_query({query}): {e}"))?;

        let gpus = rows
            .into_iter()
            .filter_map(|row| {
                let pnp = variant_string(row.get("PNPDeviceID"))?;
                let (vendor_id, device_id) = parse_ven_dev(&pnp)?;
                let vendor = vendor_from_pci_id(vendor_id);
                let model = variant_string(row.get("Name"))
                    .unwrap_or_else(|| fallback_model_name(vendor));
                let memory_mb =
                    variant_u64(row.get("AdapterRAM")).map_or(0, |b| b / BYTES_PER_MIB);

                Some(GpuInfo {
                    pci_bus_id: pci_bus_id_from_pnp(&pnp, vendor_id, device_id),
                    vendor: vendor.to_string(),
                    model,
                    memory_mb,
                    device_path: format!(r"\\.\DISPLAY#{pnp}"),
                    render_path: None,
                })
            })
            .collect();

        Ok(gpus)
    }

    /// Placeholder model name for an adapter whose WMI `Name` is missing.
    ///
    /// Uses the same spellings as the Linux backend's placeholders.
    fn fallback_model_name(vendor: &str) -> String {
        match vendor {
            "nvidia" => "NVIDIA GPU",
            "amd" => "AMD GPU",
            "intel" => "Intel GPU",
            _ => "Unknown GPU",
        }
        .to_string()
    }

    /// Extract the `(vendor, device)` ids from a PnP id such as
    /// `PCI\VEN_10DE&DEV_2204&...`. Matching is case-insensitive.
    fn parse_ven_dev(pnp: &str) -> Option<(u16, u16)> {
        let upper = pnp.to_ascii_uppercase();
        let ven = extract_hex(&upper, "VEN_", 4)?;
        let dev = extract_hex(&upper, "DEV_", 4)?;
        Some((ven, dev))
    }

    /// Parse the `nibbles` hex digits that follow the first `marker` in `s`.
    fn extract_hex(s: &str, marker: &str, nibbles: usize) -> Option<u16> {
        let start = s.find(marker)? + marker.len();
        let hex = s.get(start..start + nibbles)?;
        u16::from_str_radix(hex, 16).ok()
    }

    /// Map a PCI vendor id to a [`GpuInfo::vendor`] tag. Both 0x1002 and
    /// 0x1022 map to AMD.
    fn vendor_from_pci_id(vendor_id: u16) -> &'static str {
        match vendor_id {
            0x10DE => "nvidia",
            0x1002 | 0x1022 => "amd",
            0x8086 => "intel",
            _ => "unknown",
        }
    }

    /// Build a pseudo bus id `0000:{vendor}:{device}.{slot}` for a WMI adapter.
    ///
    /// WMI does not expose the PCI bus number, so the vendor and device ids take
    /// the bus/device positions; this lets [`pci_key_from_bus_id`] recover them.
    /// `slot` is taken from the first (up to) four characters after the last `&`
    /// of the PnP instance path, parsed as hexadecimal because that suffix is
    /// written in hex. Parsing it as decimal collapsed every suffix containing
    /// A-F to 0. `slot` is 0 when the suffix does not parse.
    ///
    /// The result is not a real PCI address and is not guaranteed to be unique.
    fn pci_bus_id_from_pnp(pnp: &str, vendor_id: u16, device_id: u16) -> String {
        let slot = pnp
            .rsplit_once('&')
            .and_then(|(_, tail)| {
                let digits: String = tail.chars().take(4).collect();
                u16::from_str_radix(&digits, 16).ok()
            })
            .unwrap_or(0);
        format!("0000:{vendor_id:04x}:{device_id:04x}.{slot:x}")
    }

    /// Return the string held by a WMI variant, or `None` for any other type.
    fn variant_string(v: Option<&Variant>) -> Option<String> {
        match v? {
            Variant::String(s) => Some(s.clone()),
            _ => None,
        }
    }

    /// Convert an integer WMI variant to `u64`; negative or non-integer values yield `None`.
    fn variant_u64(v: Option<&Variant>) -> Option<u64> {
        match v? {
            Variant::UI1(n) => Some(u64::from(*n)),
            Variant::UI2(n) => Some(u64::from(*n)),
            Variant::UI4(n) => Some(u64::from(*n)),
            Variant::UI8(n) => Some(*n),
            Variant::I1(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
            Variant::I2(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
            Variant::I4(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
            Variant::I8(n) if *n >= 0 => Some(u64::try_from(*n).unwrap_or(0)),
            _ => None,
        }
    }

    /// Read AMD adapters' memory sizes (bytes) from the display class registry
    /// keys, keyed by `(vendor id, device id)`.
    ///
    /// If several adapter subkeys share the same ids, the last one enumerated wins.
    ///
    /// # Errors
    /// Returns a description of the failure if the class key cannot be opened or
    /// its subkeys cannot be enumerated. Individual adapters that cannot be read
    /// are skipped.
    fn collect_amd_registry_vram() -> Result<HashMap<(u16, u16), u64>, String> {
        let class_path = format!(r"SYSTEM\CurrentControlSet\Control\Class\{DISPLAY_CLASS_GUID}");
        let class_key = windows_registry::LOCAL_MACHINE
            .open(&class_path)
            .map_err(|e| format!("open HKLM\\{class_path}: {e}"))?;
        let subkeys = class_key
            .keys()
            .map_err(|e| format!("enumerate class subkeys: {e}"))?;

        let mut out: HashMap<(u16, u16), u64> = HashMap::new();
        for name in subkeys {
            // Adapter instance subkeys have four-digit names; skip anything else.
            if name.len() != 4 || !name.chars().all(|c| c.is_ascii_digit()) {
                continue;
            }

            let Ok(adapter_key) = class_key.open(&name) else {
                continue;
            };

            let Some((ven, dev)) = adapter_key
                .get_string("MatchingDeviceId")
                .ok()
                .as_deref()
                .and_then(parse_matching_device_id)
            else {
                continue;
            };

            if vendor_from_pci_id(ven) != "amd" {
                continue;
            }

            if let Ok(bytes) = adapter_key.get_u64(QW_MEMORY_SIZE_VALUE) {
                out.insert((ven, dev), bytes);
            }
        }

        Ok(out)
    }

    /// Parse a registry `MatchingDeviceId` (e.g. `pci\ven_1002&dev_73a5`) into
    /// `(vendor, device)` ids.
    fn parse_matching_device_id(s: &str) -> Option<(u16, u16)> {
        parse_ven_dev(s)
    }

    /// Recover `(vendor, device)` from a pseudo bus id built by [`pci_bus_id_from_pnp`].
    fn pci_key_from_bus_id(bus_id: &str) -> Option<(u16, u16)> {
        let mut parts = bus_id.split(':');
        let _domain = parts.next()?;
        let ven = u16::from_str_radix(parts.next()?, 16).ok()?;
        let dev_fn = parts.next()?;
        let dev = dev_fn
            .split('.')
            .next()
            .and_then(|h| u16::from_str_radix(h, 16).ok())?;
        Some((ven, dev))
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        fn nvml_gpu(model: &str) -> GpuInfo {
            GpuInfo {
                pci_bus_id: String::new(),
                vendor: "nvidia".to_string(),
                model: model.to_string(),
                memory_mb: 0,
                device_path: String::new(),
                render_path: None,
            }
        }

        #[test]
        fn parse_ven_dev_nvidia() {
            let pnp = r"PCI\VEN_10DE&DEV_2204&SUBSYS_38811462&REV_A1\4&31DE5EF7&0&0008";
            assert_eq!(parse_ven_dev(pnp), Some((0x10DE, 0x2204)));
        }

        #[test]
        fn parse_ven_dev_amd() {
            let pnp = r"PCI\VEN_1002&DEV_73A5&SUBSYS_E4571DA2&REV_C0\4&1A2B3C4D&0&0010";
            assert_eq!(parse_ven_dev(pnp), Some((0x1002, 0x73A5)));
        }

        #[test]
        fn parse_ven_dev_intel_lowercase() {
            let pnp = r"pci\ven_8086&dev_9a49&subsys_00000000&rev_01\3&11583659&0&10";
            assert_eq!(parse_ven_dev(pnp), Some((0x8086, 0x9A49)));
        }

        #[test]
        fn parse_ven_dev_rejects_malformed() {
            assert_eq!(parse_ven_dev("USB\\VID_1234&PID_5678"), None);
            assert_eq!(parse_ven_dev(""), None);
        }

        #[test]
        fn vendor_id_mapping() {
            assert_eq!(vendor_from_pci_id(0x10DE), "nvidia");
            assert_eq!(vendor_from_pci_id(0x1002), "amd");
            assert_eq!(vendor_from_pci_id(0x1022), "amd");
            assert_eq!(vendor_from_pci_id(0x8086), "intel");
            assert_eq!(vendor_from_pci_id(0x1234), "unknown");
        }

        #[test]
        fn canonicalize_strips_nvml_domain_padding() {
            assert_eq!(canonicalize_pci_bus_id("00000000:01:00.0"), "0000:01:00.0");
            assert_eq!(canonicalize_pci_bus_id("0000:17:00.0"), "0000:17:00.0");
            assert_eq!(canonicalize_pci_bus_id("000a:17:00.0"), "000a:17:00.0");
        }

        #[test]
        fn canonicalize_handles_missing_colon() {
            assert_eq!(canonicalize_pci_bus_id("WEIRD"), "weird");
        }

        #[test]
        fn pci_key_from_bus_id_roundtrip() {
            let bus = pci_bus_id_from_pnp(
                r"PCI\VEN_1002&DEV_73A5&SUBSYS_E4571DA2&REV_C0\4&1A2B3C4D&0&0010",
                0x1002,
                0x73A5,
            );
            assert_eq!(pci_key_from_bus_id(&bus), Some((0x1002, 0x73A5)));
        }

        #[test]
        fn pnp_slot_suffix_is_hex() {
            let bus = pci_bus_id_from_pnp(
                r"PCI\VEN_1002&DEV_73A5&SUBSYS_E4571DA2&REV_C0\4&1A2B3C4D&0&00E0",
                0x1002,
                0x73A5,
            );
            assert_eq!(bus, "0000:1002:73a5.e0");
            let bus = pci_bus_id_from_pnp(
                r"PCI\VEN_10DE&DEV_2204&SUBSYS_38811462&REV_A1\4&31DE5EF7&0&0008",
                0x10DE,
                0x2204,
            );
            assert_eq!(bus, "0000:10de:2204.8");
        }

        #[test]
        fn claim_prefers_matching_model_then_first_unclaimed() {
            let nvml = vec![nvml_gpu("NVIDIA GeForce RTX 4090"), nvml_gpu("NVIDIA RTX A6000")];
            let mut claimed = vec![false; nvml.len()];
            assert_eq!(claim_nvml_entry(&nvml, &mut claimed, " nvidia rtx a6000 "), Some(1));
            assert_eq!(claim_nvml_entry(&nvml, &mut claimed, "NVIDIA RTX A6000"), Some(0));
            assert_eq!(claim_nvml_entry(&nvml, &mut claimed, "NVIDIA RTX A6000"), None);
        }

        #[test]
        fn fallback_model_names_match_linux_placeholders() {
            assert_eq!(fallback_model_name("nvidia"), "NVIDIA GPU");
            assert_eq!(fallback_model_name("intel"), "Intel GPU");
            assert_eq!(fallback_model_name("whatever"), "Unknown GPU");
        }

        #[test]
        fn variant_u64_accepts_unsigned_widths() {
            assert_eq!(variant_u64(Some(&Variant::UI4(4096))), Some(4096));
            assert_eq!(
                variant_u64(Some(&Variant::UI8(17_179_869_184))),
                Some(17_179_869_184)
            );
            assert_eq!(variant_u64(Some(&Variant::UI1(7))), Some(7));
        }

        #[test]
        fn variant_u64_rejects_negative_signed() {
            assert_eq!(variant_u64(Some(&Variant::I4(-1))), None);
        }

        #[test]
        fn variant_string_unwraps() {
            assert_eq!(
                variant_string(Some(&Variant::String("NVIDIA GeForce RTX 4090".into()))),
                Some("NVIDIA GeForce RTX 4090".to_string())
            );
            assert_eq!(variant_string(Some(&Variant::UI4(7))), None);
            assert_eq!(variant_string(None), None);
        }
    }
}
941
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble a `GpuInfo` from borrowed field values.
    fn sample_gpu(
        pci_bus_id: &str,
        vendor: &str,
        model: &str,
        memory_mb: u64,
        device_path: &str,
        render_path: Option<&str>,
    ) -> GpuInfo {
        GpuInfo {
            pci_bus_id: pci_bus_id.to_owned(),
            vendor: vendor.to_owned(),
            model: model.to_owned(),
            memory_mb,
            device_path: device_path.to_owned(),
            render_path: render_path.map(str::to_owned),
        }
    }

    /// Serialize `info` to JSON, parse it back, and require the two to be equal.
    fn assert_json_roundtrip(info: &GpuInfo) {
        let encoded = serde_json::to_string(info).unwrap();
        let decoded: GpuInfo = serde_json::from_str(&encoded).unwrap();
        assert_eq!(*info, decoded);
    }

    #[test]
    fn test_gpu_info_serialization_roundtrip() {
        assert_json_roundtrip(&sample_gpu(
            "0000:01:00.0",
            "nvidia",
            "NVIDIA A100-SXM4-80GB",
            81920,
            "/dev/nvidia0",
            None,
        ));
    }

    #[test]
    fn test_gpu_info_amd_serialization() {
        assert_json_roundtrip(&sample_gpu(
            "0000:03:00.0",
            "amd",
            "AMD GPU",
            16384,
            "/dev/dri/card0",
            Some("/dev/dri/renderD128"),
        ));
    }

    #[test]
    fn test_gpu_info_apple_serialization() {
        assert_json_roundtrip(&sample_gpu(
            "apple:0",
            "apple",
            "Apple M2 Pro",
            32768,
            "iokit://AppleGPU/0",
            None,
        ));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_nvidia() {
        assert_eq!(
            find_device_paths("0000:01:00.0", "nvidia", 0),
            ("/dev/nvidia0".to_owned(), None)
        );
        assert_eq!(
            find_device_paths("0000:02:00.0", "nvidia", 1),
            ("/dev/nvidia1".to_owned(), None)
        );
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_amd() {
        assert_eq!(
            find_device_paths("0000:03:00.0", "amd", 0),
            (
                "/dev/dri/card0".to_owned(),
                Some("/dev/dri/renderD128".to_owned())
            )
        );
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_find_device_paths_intel() {
        assert_eq!(
            find_device_paths("0000:00:02.0", "intel", 0),
            (
                "/dev/dri/card0".to_owned(),
                Some("/dev/dri/renderD128".to_owned())
            )
        );
    }

    #[test]
    fn test_detect_gpus_returns_vec() {
        // Hardware-dependent: the machine may have no GPUs at all, but every GPU
        // that is reported must have its identifying fields filled in.
        for gpu in detect_gpus() {
            let required = [
                ("pci_bus_id", &gpu.pci_bus_id),
                ("vendor", &gpu.vendor),
                ("model", &gpu.model),
                ("device_path", &gpu.device_path),
            ];
            for (field, value) in required {
                assert!(!value.is_empty(), "{field} is empty in {gpu:?}");
            }
        }
    }
}