//! GPU backend detection and backend modules (scirs2_core/gpu/backends/mod.rs).
use crate::gpu::{GpuBackend, GpuError};
7use std::process::Command;
8
9#[cfg(all(target_os = "macos", feature = "serialization"))]
10use serde_json;
11
12#[cfg(feature = "validation")]
13use regex::Regex;
14
15#[cfg(feature = "cuda")]
17pub mod cuda;
18
19#[cfg(feature = "opencl")]
20pub mod opencl;
21
22#[cfg(feature = "wgpu_backend")]
23pub mod wgpu;
24
25#[cfg(all(feature = "metal", target_os = "macos"))]
26pub mod metal;
27
28#[cfg(all(feature = "metal", target_os = "macos"))]
29pub mod metal_mps;
30
31#[cfg(all(feature = "metal", target_os = "macos"))]
36pub mod msl_kernels;
37
38#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
39pub mod metal_mpsgraph;
40
41#[cfg(feature = "cuda")]
43pub use cuda::{get_optimizer_kernels, CudaContext, CudaStream};
44
45#[cfg(feature = "opencl")]
46pub use opencl::OpenCLContext;
47
48#[cfg(feature = "wgpu_backend")]
49pub use wgpu::WebGPUContext;
50
51#[cfg(all(feature = "metal", target_os = "macos"))]
52pub use metal::{MetalBufferOptions, MetalContext, MetalStorageMode};
53
54#[cfg(all(feature = "metal", target_os = "macos"))]
55pub use metal_mps::{MPSContext, MPSDataType, MPSOperations};
56
57#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
58pub use metal_mpsgraph::MPSGraphContext;
59
/// Information about a single detected GPU (or CPU-fallback) device.
#[derive(Debug, Clone)]
pub struct GpuInfo {
    /// Backend that exposes this device.
    pub backend: GpuBackend,
    /// Human-readable device name as reported by the detection tool.
    pub device_name: String,
    /// Total device memory in bytes, when the detection tool reported it.
    pub memory_bytes: Option<u64>,
    /// Backend-specific capability tag (e.g. CUDA "8.6", "RDNA/CDNA", Metal family).
    pub compute_capability: Option<String>,
    /// True when tensor acceleration is expected: set for CUDA devices with
    /// compute capability >= 7, and unconditionally for ROCm and Metal devices.
    pub supports_tensors: bool,
}
74
/// Result of scanning the system for GPU backends.
#[derive(Debug, Clone)]
pub struct GpuDetectionResult {
    /// All detected devices; always contains at least the CPU fallback entry.
    pub devices: Vec<GpuInfo>,
    /// Preferred backend given the detected devices
    /// (Cuda > Rocm > Metal > OpenCL > Cpu).
    pub recommended_backend: GpuBackend,
}
83
84#[allow(dead_code)]
86pub fn detect_gpu_backends() -> GpuDetectionResult {
87 let mut devices = Vec::new();
88
89 #[cfg(not(test))]
91 {
92 if let Ok(cuda_devices) = detect_cuda_devices() {
94 devices.extend(cuda_devices);
95 }
96
97 if let Ok(rocm_devices) = detect_rocm_devices() {
99 devices.extend(rocm_devices);
100 }
101
102 #[cfg(target_os = "macos")]
104 if let Ok(metal_devices) = detect_metal_devices() {
105 devices.extend(metal_devices);
106 }
107
108 if let Ok(opencl_devices) = detect_opencl_devices() {
110 devices.extend(opencl_devices);
111 }
112 }
113
114 let recommended_backend = if devices
116 .iter()
117 .any(|d: &GpuInfo| d.backend == GpuBackend::Cuda)
118 {
119 GpuBackend::Cuda
120 } else if devices
121 .iter()
122 .any(|d: &GpuInfo| d.backend == GpuBackend::Rocm)
123 {
124 GpuBackend::Rocm
125 } else if devices
126 .iter()
127 .any(|d: &GpuInfo| d.backend == GpuBackend::Metal)
128 {
129 GpuBackend::Metal
130 } else if devices
131 .iter()
132 .any(|d: &GpuInfo| d.backend == GpuBackend::OpenCL)
133 {
134 GpuBackend::OpenCL
135 } else {
136 GpuBackend::Cpu
137 };
138
139 devices.push(GpuInfo {
141 backend: GpuBackend::Cpu,
142 device_name: "CPU".to_string(),
143 memory_bytes: None,
144 compute_capability: None,
145 supports_tensors: false,
146 });
147
148 GpuDetectionResult {
149 devices,
150 recommended_backend,
151 }
152}
153
154#[allow(dead_code)]
156fn detect_rocm_devices() -> Result<Vec<GpuInfo>, GpuError> {
157 let mut devices = Vec::new();
158
159 match Command::new("rocm-smi")
161 .arg("--showproductname")
162 .arg("--showmeminfo")
163 .arg("vram")
164 .arg("--csv")
165 .output()
166 {
167 Ok(output) if output.status.success() => {
168 let output_str = String::from_utf8_lossy(&output.stdout);
169
170 for line in output_str.lines().skip(1) {
171 if line.trim().is_empty() {
173 continue;
174 }
175
176 let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
177 if parts.len() >= 3 {
178 let device_name = parts[1].trim_matches('"').to_string();
179 let memory_str = parts[2].trim_matches('"');
180
181 let memory_mb = memory_str
183 .split_whitespace()
184 .next()
185 .and_then(|s| s.parse::<u64>().ok())
186 .unwrap_or(0)
187 * 1024
188 * 1024; devices.push(GpuInfo {
191 backend: GpuBackend::Rocm,
192 device_name,
193 memory_bytes: Some(memory_mb),
194 compute_capability: Some("RDNA/CDNA".to_string()),
195 supports_tensors: true, });
197 }
198 }
199 }
200 _ => {
201 }
207 }
208
209 if devices.is_empty() {
210 Err(GpuError::BackendNotAvailable("ROCm".to_string()))
211 } else {
212 Ok(devices)
213 }
214}
215
216#[allow(dead_code)]
218fn detect_cuda_devices() -> Result<Vec<GpuInfo>, GpuError> {
219 let mut devices = Vec::new();
220
221 match Command::new("nvidia-smi")
223 .arg("--query-gpu=name,memory.total,compute_cap")
224 .arg("--format=csv,noheader,nounits")
225 .output()
226 {
227 Ok(output) if output.status.success() => {
228 let output_str = String::from_utf8_lossy(&output.stdout);
229
230 for line in output_str.lines() {
231 if line.trim().is_empty() {
232 continue;
233 }
234
235 let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
236 if parts.len() >= 3 {
237 let device_name = parts[0].to_string();
238 let memory_mb = parts[1].parse::<u64>().unwrap_or(0) * 1024 * 1024; let compute_capability = parts[2].to_string();
240
241 let supports_tensors =
243 if let Some(major_str) = compute_capability.split('.').next() {
244 major_str.parse::<u32>().unwrap_or(0) >= 7 } else {
246 false
247 };
248
249 devices.push(GpuInfo {
250 backend: GpuBackend::Cuda,
251 device_name,
252 memory_bytes: Some(memory_mb),
253 compute_capability: Some(compute_capability),
254 supports_tensors,
255 });
256 }
257 }
258 }
259 _ => {
260 }
266 }
267
268 if devices.is_empty() {
269 Err(GpuError::BackendNotAvailable("CUDA".to_string()))
270 } else {
271 Ok(devices)
272 }
273}
274
#[cfg(target_os = "macos")]
#[allow(dead_code)]
/// Detects Metal-capable GPUs on macOS.
///
/// Strategy, in order:
/// 1. Parse `system_profiler SPDisplaysDataType -json` (needs the
///    `serialization` feature for JSON parsing; the `validation` feature
///    additionally enables regex-based VRAM extraction).
/// 2. If that yields nothing, query the Metal API directly when the `metal`
///    feature is enabled; otherwise push a generic placeholder entry.
///
/// # Errors
/// Returns `GpuError::BackendNotAvailable("Metal")` when no device can be
/// found by any path.
fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
    let mut devices = Vec::new();

    match Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-json")
        .output()
    {
        Ok(output) if output.status.success() => {
            // JSON parsing of profiler output is gated on `serialization`.
            #[cfg(feature = "serialization")]
            {
                use std::str::FromStr;
                let output_str = String::from_utf8_lossy(&output.stdout);

                if let Ok(json_value) = serde_json::Value::from_str(&output_str) {
                    if let Some(displays) = json_value
                        .get("SPDisplaysDataType")
                        .and_then(|v: &serde_json::Value| v.as_array())
                    {
                        // Matches VRAM strings such as "8 GB" or "512 MB".
                        #[cfg(feature = "validation")]
                        let vram_regex = Regex::new(r"(\d+)\s*(GB|MB)").ok();

                        for display in displays {
                            if let Some(model) = display
                                .get("sppci_model")
                                .and_then(|v: &serde_json::Value| v.as_str())
                            {
                                let mut gpu_info = GpuInfo {
                                    backend: GpuBackend::Metal,
                                    device_name: model.to_string(),
                                    memory_bytes: None,
                                    compute_capability: None,
                                    supports_tensors: true,
                                };

                                // VRAM may appear under "vram_pcie" or "vram"
                                // depending on the hardware.
                                if let Some(vram_str) = display
                                    .get("vram_pcie")
                                    .and_then(|v: &serde_json::Value| v.as_str())
                                    .or_else(|| {
                                        display
                                            .get("vram")
                                            .and_then(|v: &serde_json::Value| v.as_str())
                                    })
                                {
                                    // Without `validation` there is no regex,
                                    // so the VRAM string stays unparsed and
                                    // memory_bytes remains None.
                                    #[cfg(feature = "validation")]
                                    if let Some(captures) =
                                        vram_regex.as_ref().and_then(|re| re.captures(vram_str))
                                    {
                                        if let (Some(value), Some(unit)) =
                                            (captures.get(1), captures.get(2))
                                        {
                                            if let Ok(num) = u64::from_str(value.as_str()) {
                                                gpu_info.memory_bytes = Some(match unit.as_str() {
                                                    "GB" => num * 1024 * 1024 * 1024,
                                                    "MB" => num * 1024 * 1024,
                                                    _ => 0,
                                                });
                                            }
                                        }
                                    }
                                }

                                // The Metal family string doubles as the
                                // capability tag for Metal devices.
                                if let Some(metal_family) = display
                                    .get("sppci_metal_family")
                                    .and_then(|v: &serde_json::Value| v.as_str())
                                {
                                    gpu_info.compute_capability = Some(metal_family.to_string());
                                }

                                devices.push(gpu_info);
                            }
                        }
                    }
                }
            }

            // Fallback when the profiler produced no devices (or the
            // `serialization` feature is disabled).
            if devices.is_empty() {
                #[cfg(feature = "metal")]
                {
                    use metal::Device;
                    if let Some(device) = Device::system_default() {
                        let name = device.name().to_string();
                        let mut gpu_info = GpuInfo {
                            backend: GpuBackend::Metal,
                            device_name: name.clone(),
                            memory_bytes: None,
                            compute_capability: None,
                            supports_tensors: true,
                        };

                        gpu_info.compute_capability = Some("Metal GPU".to_string());

                        devices.push(gpu_info);
                    }
                }

                #[cfg(not(feature = "metal"))]
                {
                    // No Metal API available; record a generic placeholder so
                    // the backend still shows up as present on macOS.
                    devices.push(GpuInfo {
                        backend: GpuBackend::Metal,
                        device_name: "Metal GPU".to_string(),
                        memory_bytes: None,
                        compute_capability: None,
                        supports_tensors: true,
                    });
                }
            }
        }
        _ => {
            // system_profiler failed entirely: try the Metal API directly,
            // or report the backend as unavailable.
            #[cfg(feature = "metal")]
            {
                use metal::Device;
                if let Some(device) = Device::system_default() {
                    devices.push(GpuInfo {
                        backend: GpuBackend::Metal,
                        device_name: device.name().to_string(),
                        memory_bytes: None,
                        compute_capability: None,
                        supports_tensors: true,
                    });
                } else {
                    return Err(GpuError::BackendNotAvailable("Metal".to_string()));
                }
            }

            #[cfg(not(feature = "metal"))]
            {
                return Err(GpuError::BackendNotAvailable("Metal".to_string()));
            }
        }
    }

    if devices.is_empty() {
        Err(GpuError::BackendNotAvailable("Metal".to_string()))
    } else {
        Ok(devices)
    }
}
429
430#[cfg(not(target_os = "macos"))]
432#[allow(dead_code)]
433fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
434 Err(GpuError::BackendNotAvailable(
435 "Metal (not macOS)".to_string(),
436 ))
437}
438
439#[allow(dead_code)]
441fn detect_opencl_devices() -> Result<Vec<GpuInfo>, GpuError> {
442 let mut devices = Vec::new();
443
444 match Command::new("clinfo").arg("--list").output() {
446 Ok(output) if output.status.success() => {
447 let output_str = String::from_utf8_lossy(&output.stdout);
448
449 for line in output_str.lines() {
450 if line.trim().starts_with("Platform") || line.trim().starts_with("Device") {
451 devices.push(GpuInfo {
454 backend: GpuBackend::OpenCL,
455 device_name: "OpenCL Device".to_string(),
456 memory_bytes: None,
457 compute_capability: None,
458 supports_tensors: false,
459 });
460 break; }
462 }
463 }
464 _ => {
465 return Err(GpuError::BackendNotAvailable("OpenCL".to_string()));
466 }
467 }
468
469 if devices.is_empty() {
470 Err(GpuError::BackendNotAvailable("OpenCL".to_string()))
471 } else {
472 Ok(devices)
473 }
474}
475
476#[allow(dead_code)]
478pub fn check_backend_installation(backend: GpuBackend) -> Result<bool, GpuError> {
479 match backend {
480 GpuBackend::Cuda => {
481 match Command::new("nvcc").arg("--version").output() {
483 Ok(output) if output.status.success() => Ok(true),
484 _ => Ok(false),
485 }
486 }
487 GpuBackend::Rocm => {
488 match Command::new("hipcc").arg("--version").output() {
490 Ok(output) if output.status.success() => Ok(true),
491 _ => {
492 match Command::new("rocm-smi").arg("--version").output() {
494 Ok(output) if output.status.success() => Ok(true),
495 _ => Ok(false),
496 }
497 }
498 }
499 }
500 GpuBackend::Metal => {
501 #[cfg(target_os = "macos")]
502 {
503 Ok(true)
505 }
506 #[cfg(not(target_os = "macos"))]
507 {
508 Ok(false)
509 }
510 }
511 GpuBackend::OpenCL => {
512 match Command::new("clinfo").output() {
514 Ok(output) if output.status.success() => Ok(true),
515 _ => Ok(false),
516 }
517 }
518 GpuBackend::Wgpu => {
519 Ok(true)
521 }
522 GpuBackend::Cpu => Ok(true),
523 }
524}
525
526#[allow(dead_code)]
528pub fn get_device_info(backend: GpuBackend, device_id: usize) -> Result<GpuInfo, GpuError> {
529 let detection_result = detect_gpu_backends();
530
531 detection_result
532 .devices
533 .into_iter()
534 .filter(|d| d.backend == backend)
535 .nth(device_id)
536 .ok_or_else(|| {
537 GpuError::InvalidParameter(format!(
538 "Device {device_id} not found for backend {:?}",
539 backend
540 ))
541 })
542}
543
544#[allow(dead_code)]
546pub fn initialize_optimal_backend() -> Result<GpuBackend, GpuError> {
547 let detection_result = detect_gpu_backends();
548
549 let preference_order = [
551 GpuBackend::Cuda, GpuBackend::Rocm, GpuBackend::Metal, GpuBackend::OpenCL, GpuBackend::Wgpu, GpuBackend::Cpu, ];
558
559 for backend in preference_order.iter() {
560 if detection_result
561 .devices
562 .iter()
563 .any(|d: &GpuInfo| d.backend == *backend)
564 {
565 return Ok(*backend);
566 }
567 }
568
569 Ok(GpuBackend::Cpu)
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    // GpuInfo fields round-trip the values they were constructed with.
    #[test]
    fn test_gpu_info_creation() {
        let info = GpuInfo {
            backend: GpuBackend::Cuda,
            device_name: "NVIDIA GeForce RTX 3080".to_string(),
            memory_bytes: Some(10 * 1024 * 1024 * 1024), // 10 GiB
            compute_capability: Some("8.6".to_string()),
            supports_tensors: true,
        };

        assert_eq!(info.backend, GpuBackend::Cuda);
        assert_eq!(info.device_name, "NVIDIA GeForce RTX 3080");
        assert_eq!(info.memory_bytes, Some(10 * 1024 * 1024 * 1024));
        assert_eq!(info.compute_capability, Some("8.6".to_string()));
        assert!(info.supports_tensors);
    }

    // Detection always yields at least the CPU fallback device, and the
    // recommendation is one of the probed backends (never Wgpu, which
    // detect_gpu_backends does not probe).
    #[test]
    fn test_gpu_detection_result_with_cpu_fallback() {
        let result = detect_gpu_backends();

        assert!(!result.devices.is_empty());
        assert!(result
            .devices
            .iter()
            .any(|d: &GpuInfo| d.backend == GpuBackend::Cpu));

        match result.recommended_backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
            _ => panic!("Unexpected recommended backend"),
        }
    }

    // CPU support is unconditional.
    #[test]
    fn test_check_backend_installation_cpu() {
        let result = check_backend_installation(GpuBackend::Cpu).expect("Operation failed");
        assert!(result);
    }

    // Wgpu is always reported installed (no external tooling required).
    #[test]
    fn test_check_backend_installation_wgpu() {
        let result = check_backend_installation(GpuBackend::Wgpu).expect("Operation failed");
        assert!(result);
    }

    // Metal availability is decided purely by the target OS.
    #[test]
    fn test_check_backend_installation_metal() {
        let result = check_backend_installation(GpuBackend::Metal).expect("Operation failed");
        #[cfg(target_os = "macos")]
        assert!(result);
        #[cfg(not(target_os = "macos"))]
        assert!(!result);
    }

    // The optimal backend must be one of the known variants; the exhaustive
    // match doubles as a compile-time check when variants are added.
    #[test]
    fn test_initialize_optimal_backend() {
        let backend = initialize_optimal_backend().expect("Operation failed");

        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    // An out-of-range device index maps to InvalidParameter.
    #[test]
    fn test_get_device_info_invalid_device() {
        let result = get_device_info(GpuBackend::Cpu, 100);

        assert!(result.is_err());
        match result {
            Err(GpuError::InvalidParameter(_)) => {}
            _ => panic!("Expected InvalidParameter error"),
        }
    }

    // The CPU fallback device is always retrievable at index 0.
    #[test]
    fn test_get_device_info_cpu() {
        let result = get_device_info(GpuBackend::Cpu, 0);

        assert!(result.is_ok());
        let info = result.expect("Operation failed");
        assert_eq!(info.backend, GpuBackend::Cpu);
        assert_eq!(info.device_name, "CPU");
        assert!(!info.supports_tensors);
    }

    // The non-macOS stub must report Metal as unavailable.
    #[test]
    fn test_detect_metal_devices_non_macos() {
        #[cfg(not(target_os = "macos"))]
        {
            let result = detect_metal_devices();
            assert!(result.is_err());
            match result {
                Err(GpuError::BackendNotAvailable(_)) => {}
                _ => panic!("Expected BackendNotAvailable error"),
            }
        }
    }

    // Clone must preserve every field of GpuInfo.
    #[test]
    fn test_gpu_info_clone() {
        let info = GpuInfo {
            backend: GpuBackend::Rocm,
            device_name: "AMD Radeon RX 6900 XT".to_string(),
            memory_bytes: Some(16 * 1024 * 1024 * 1024), // 16 GiB
            compute_capability: Some("RDNA2".to_string()),
            supports_tensors: true,
        };

        let cloned = info.clone();
        assert_eq!(info.backend, cloned.backend);
        assert_eq!(info.device_name, cloned.device_name);
        assert_eq!(info.memory_bytes, cloned.memory_bytes);
        assert_eq!(info.compute_capability, cloned.compute_capability);
        assert_eq!(info.supports_tensors, cloned.supports_tensors);
    }

    // Clone must preserve the device list and the recommendation.
    #[test]
    fn test_gpu_detection_result_clone() {
        let devices = vec![
            GpuInfo {
                backend: GpuBackend::Cuda,
                device_name: "NVIDIA A100".to_string(),
                memory_bytes: Some(40 * 1024 * 1024 * 1024),
                compute_capability: Some("8.0".to_string()),
                supports_tensors: true,
            },
            GpuInfo {
                backend: GpuBackend::Cpu,
                device_name: "CPU".to_string(),
                memory_bytes: None,
                compute_capability: None,
                supports_tensors: false,
            },
        ];

        let result = GpuDetectionResult {
            devices: devices.clone(),
            recommended_backend: GpuBackend::Cuda,
        };

        let cloned = result.clone();
        assert_eq!(result.devices.len(), cloned.devices.len());
        assert_eq!(result.recommended_backend, cloned.recommended_backend);
    }

    // Smoke tests: the probes must not panic regardless of what hardware or
    // vendor tooling is present on the machine running the suite.
    #[test]
    fn test_detect_cuda_deviceserror_handling() {
        let _ = detect_cuda_devices();
    }

    #[test]
    fn test_detect_rocm_deviceserror_handling() {
        let _ = detect_rocm_devices();
    }

    #[test]
    fn test_detect_opencl_deviceserror_handling() {
        let _ = detect_opencl_devices();
    }

    // When CUDA hardware is detected, it must win the preference ordering.
    #[test]
    fn test_backend_preference_order() {
        let result = detect_gpu_backends();

        if result
            .devices
            .iter()
            .any(|d: &GpuInfo| d.backend == GpuBackend::Cuda)
        {
            let optimal = initialize_optimal_backend().expect("Operation failed");
            if result
                .devices
                .iter()
                .filter(|d| d.backend == GpuBackend::Cuda)
                .count()
                > 0
            {
                assert_eq!(optimal, GpuBackend::Cuda);
            }
        }
    }
}