scirs2_core/gpu/backends/
mod.rs1use crate::gpu::{GpuBackend, GpuError};
7use std::process::Command;
8
9#[cfg(target_os = "macos")]
10use serde_json;
11
12#[cfg(feature = "validation")]
13use regex::Regex;
14
15#[cfg(feature = "cuda")]
17pub mod cuda;
18
19#[cfg(feature = "opencl")]
20pub mod opencl;
21
22#[cfg(feature = "wgpu_backend")]
23pub mod wgpu;
24
25#[cfg(all(feature = "metal", target_os = "macos"))]
26pub mod metal;
27
28#[cfg(all(feature = "metal", target_os = "macos"))]
29pub mod metal_mps;
30
31#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
32pub mod metal_mpsgraph;
33
34#[cfg(feature = "cuda")]
36pub use cuda::{get_optimizer_kernels, CudaContext, CudaStream};
37
38#[cfg(feature = "opencl")]
39pub use opencl::OpenCLContext;
40
41#[cfg(feature = "wgpu_backend")]
42pub use wgpu::WebGPUContext;
43
44#[cfg(all(feature = "metal", target_os = "macos"))]
45pub use metal::{MetalBufferOptions, MetalContext, MetalStorageMode};
46
47#[cfg(all(feature = "metal", target_os = "macos"))]
48pub use metal_mps::{MPSContext, MPSDataType, MPSOperations};
49
50#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
51pub use metal_mpsgraph::MPSGraphContext;
52
53#[derive(Debug, Clone)]
55pub struct GpuInfo {
56 pub backend: GpuBackend,
58 pub device_name: String,
60 pub memory_bytes: Option<u64>,
62 pub compute_capability: Option<String>,
64 pub supports_tensors: bool,
66}
67
68#[derive(Debug, Clone)]
70pub struct GpuDetectionResult {
71 pub devices: Vec<GpuInfo>,
73 pub recommended_backend: GpuBackend,
75}
76
77#[allow(dead_code)]
79pub fn detect_gpu_backends() -> GpuDetectionResult {
80 let mut devices = Vec::new();
81
82 #[cfg(not(test))]
84 {
85 if let Ok(cuda_devices) = detect_cuda_devices() {
87 devices.extend(cuda_devices);
88 }
89
90 if let Ok(rocm_devices) = detect_rocm_devices() {
92 devices.extend(rocm_devices);
93 }
94
95 #[cfg(target_os = "macos")]
97 if let Ok(metal_devices) = detect_metal_devices() {
98 devices.extend(metal_devices);
99 }
100
101 if let Ok(opencl_devices) = detect_opencl_devices() {
103 devices.extend(opencl_devices);
104 }
105 }
106
107 let recommended_backend = if devices
109 .iter()
110 .any(|d: &GpuInfo| d.backend == GpuBackend::Cuda)
111 {
112 GpuBackend::Cuda
113 } else if devices
114 .iter()
115 .any(|d: &GpuInfo| d.backend == GpuBackend::Rocm)
116 {
117 GpuBackend::Rocm
118 } else if devices
119 .iter()
120 .any(|d: &GpuInfo| d.backend == GpuBackend::Metal)
121 {
122 GpuBackend::Metal
123 } else if devices
124 .iter()
125 .any(|d: &GpuInfo| d.backend == GpuBackend::OpenCL)
126 {
127 GpuBackend::OpenCL
128 } else {
129 GpuBackend::Cpu
130 };
131
132 devices.push(GpuInfo {
134 backend: GpuBackend::Cpu,
135 device_name: "CPU".to_string(),
136 memory_bytes: None,
137 compute_capability: None,
138 supports_tensors: false,
139 });
140
141 GpuDetectionResult {
142 devices,
143 recommended_backend,
144 }
145}
146
147#[allow(dead_code)]
149fn detect_rocm_devices() -> Result<Vec<GpuInfo>, GpuError> {
150 let mut devices = Vec::new();
151
152 match Command::new("rocm-smi")
154 .arg("--showproductname")
155 .arg("--showmeminfo")
156 .arg("vram")
157 .arg("--csv")
158 .output()
159 {
160 Ok(output) if output.status.success() => {
161 let output_str = String::from_utf8_lossy(&output.stdout);
162
163 for line in output_str.lines().skip(1) {
164 if line.trim().is_empty() {
166 continue;
167 }
168
169 let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
170 if parts.len() >= 3 {
171 let device_name = parts[1].trim_matches('"').to_string();
172 let memory_str = parts[2].trim_matches('"');
173
174 let memory_mb = memory_str
176 .split_whitespace()
177 .next()
178 .and_then(|s| s.parse::<u64>().ok())
179 .unwrap_or(0)
180 * 1024
181 * 1024; devices.push(GpuInfo {
184 backend: GpuBackend::Rocm,
185 device_name,
186 memory_bytes: Some(memory_mb),
187 compute_capability: Some("RDNA/CDNA".to_string()),
188 supports_tensors: true, });
190 }
191 }
192 }
193 _ => {
194 }
200 }
201
202 if devices.is_empty() {
203 Err(GpuError::BackendNotAvailable("ROCm".to_string()))
204 } else {
205 Ok(devices)
206 }
207}
208
209#[allow(dead_code)]
211fn detect_cuda_devices() -> Result<Vec<GpuInfo>, GpuError> {
212 let mut devices = Vec::new();
213
214 match Command::new("nvidia-smi")
216 .arg("--query-gpu=name,memory.total,compute_cap")
217 .arg("--format=csv,noheader,nounits")
218 .output()
219 {
220 Ok(output) if output.status.success() => {
221 let output_str = String::from_utf8_lossy(&output.stdout);
222
223 for line in output_str.lines() {
224 if line.trim().is_empty() {
225 continue;
226 }
227
228 let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
229 if parts.len() >= 3 {
230 let device_name = parts[0].to_string();
231 let memory_mb = parts[1].parse::<u64>().unwrap_or(0) * 1024 * 1024; let compute_capability = parts[2].to_string();
233
234 let supports_tensors =
236 if let Some(major_str) = compute_capability.split('.').next() {
237 major_str.parse::<u32>().unwrap_or(0) >= 7 } else {
239 false
240 };
241
242 devices.push(GpuInfo {
243 backend: GpuBackend::Cuda,
244 device_name,
245 memory_bytes: Some(memory_mb),
246 compute_capability: Some(compute_capability),
247 supports_tensors,
248 });
249 }
250 }
251 }
252 _ => {
253 }
259 }
260
261 if devices.is_empty() {
262 Err(GpuError::BackendNotAvailable("CUDA".to_string()))
263 } else {
264 Ok(devices)
265 }
266}
267
/// Detects Metal-capable GPUs on macOS.
///
/// Devices are read from `system_profiler SPDisplaysDataType -json`. If that yields
/// nothing, the behaviour depends on the `metal` feature:
/// - With the feature, the system default device from the `metal` crate is
///   reported, if one exists.
/// - Without it, a generic `"Metal GPU"` entry is reported, but only when
///   `system_profiler` itself ran successfully.
///
/// # Errors
///
/// Returns [`GpuError::BackendNotAvailable`] when no Metal device could be found.
#[cfg(target_os = "macos")]
#[allow(dead_code)]
fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
    let profiler_output = Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-json")
        .output()
        .ok()
        .filter(|output| output.status.success());

    let mut devices = match &profiler_output {
        Some(output) => parse_display_profile(&String::from_utf8_lossy(&output.stdout)),
        None => Vec::new(),
    };

    if devices.is_empty() {
        #[cfg(feature = "metal")]
        {
            devices.extend(default_metal_device());
        }

        #[cfg(not(feature = "metal"))]
        {
            if profiler_output.is_some() {
                devices.push(GpuInfo {
                    backend: GpuBackend::Metal,
                    device_name: "Metal GPU".to_string(),
                    memory_bytes: None,
                    compute_capability: None,
                    supports_tensors: true,
                });
            }
        }
    }

    if devices.is_empty() {
        Err(GpuError::BackendNotAvailable("Metal".to_string()))
    } else {
        Ok(devices)
    }
}

/// Builds one [`GpuInfo`] per `SPDisplaysDataType` entry that has an
/// `sppci_model` string.
///
/// - `memory_bytes` is parsed from `vram_pcie`, or from `vram` when `vram_pcie` is
///   absent. Parsing needs the `validation` feature (which provides `regex`);
///   without it memory stays `None`.
/// - `compute_capability` is the `sppci_metal_family` string when present.
///
/// Malformed JSON or a missing array yields an empty list.
#[cfg(target_os = "macos")]
#[allow(dead_code)]
fn parse_display_profile(json: &str) -> Vec<GpuInfo> {
    use std::str::FromStr;

    let root = match serde_json::Value::from_str(json) {
        Ok(value) => value,
        Err(_) => return Vec::new(),
    };
    let displays = match root.get("SPDisplaysDataType").and_then(|v| v.as_array()) {
        Some(displays) => displays,
        None => return Vec::new(),
    };

    #[cfg(feature = "validation")]
    let vram_regex = Regex::new(r"(\d+)\s*(GB|MB)").ok();

    displays
        .iter()
        .filter_map(|display| {
            let model = display.get("sppci_model").and_then(|v| v.as_str())?;
            let vram = display
                .get("vram_pcie")
                .and_then(|v| v.as_str())
                .or_else(|| display.get("vram").and_then(|v| v.as_str()));

            #[cfg(feature = "validation")]
            let memory_bytes = vram.and_then(|text| vram_to_bytes(text, vram_regex.as_ref()?));
            #[cfg(not(feature = "validation"))]
            let memory_bytes = {
                // VRAM strings are only parsed when the `validation` feature is enabled.
                let _ = vram;
                None
            };

            Some(GpuInfo {
                backend: GpuBackend::Metal,
                device_name: model.to_string(),
                memory_bytes,
                compute_capability: display
                    .get("sppci_metal_family")
                    .and_then(|v| v.as_str())
                    .map(str::to_string),
                supports_tensors: true,
            })
        })
        .collect()
}

/// Converts the first `<number> GB` / `<number> MB` match of `re` in `vram` to bytes.
///
/// Returns `None` if there is no match, the number does not parse, or the
/// result overflows `u64`.
#[cfg(all(target_os = "macos", feature = "validation"))]
#[allow(dead_code)]
fn vram_to_bytes(vram: &str, re: &Regex) -> Option<u64> {
    let captures = re.captures(vram)?;
    let amount: u64 = captures.get(1)?.as_str().parse().ok()?;
    let scale: u64 = match captures.get(2)?.as_str() {
        "GB" => 1024 * 1024 * 1024,
        "MB" => 1024 * 1024,
        _ => return None,
    };
    amount.checked_mul(scale)
}

/// Returns the system default Metal device reported by the `metal` crate, if any.
#[cfg(all(target_os = "macos", feature = "metal"))]
#[allow(dead_code)]
fn default_metal_device() -> Option<GpuInfo> {
    // `::metal` names the external crate explicitly. A bare `metal::` path would be
    // ambiguous with this module's own `metal` submodule, which is compiled under
    // the same cfg.
    ::metal::Device::system_default().map(|device| GpuInfo {
        backend: GpuBackend::Metal,
        device_name: device.name().to_string(),
        memory_bytes: None,
        compute_capability: None,
        supports_tensors: true,
    })
}
412
413#[cfg(not(target_os = "macos"))]
415#[allow(dead_code)]
416fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
417 Err(GpuError::BackendNotAvailable(
418 "Metal (not macOS)".to_string(),
419 ))
420}
421
422#[allow(dead_code)]
424fn detect_opencl_devices() -> Result<Vec<GpuInfo>, GpuError> {
425 let mut devices = Vec::new();
426
427 match Command::new("clinfo").arg("--list").output() {
429 Ok(output) if output.status.success() => {
430 let output_str = String::from_utf8_lossy(&output.stdout);
431
432 for line in output_str.lines() {
433 if line.trim().starts_with("Platform") || line.trim().starts_with("Device") {
434 devices.push(GpuInfo {
437 backend: GpuBackend::OpenCL,
438 device_name: "OpenCL Device".to_string(),
439 memory_bytes: None,
440 compute_capability: None,
441 supports_tensors: false,
442 });
443 break; }
445 }
446 }
447 _ => {
448 return Err(GpuError::BackendNotAvailable("OpenCL".to_string()));
449 }
450 }
451
452 if devices.is_empty() {
453 Err(GpuError::BackendNotAvailable("OpenCL".to_string()))
454 } else {
455 Ok(devices)
456 }
457}
458
459#[allow(dead_code)]
461pub fn check_backend_installation(backend: GpuBackend) -> Result<bool, GpuError> {
462 match backend {
463 GpuBackend::Cuda => {
464 match Command::new("nvcc").arg("--version").output() {
466 Ok(output) if output.status.success() => Ok(true),
467 _ => Ok(false),
468 }
469 }
470 GpuBackend::Rocm => {
471 match Command::new("hipcc").arg("--version").output() {
473 Ok(output) if output.status.success() => Ok(true),
474 _ => {
475 match Command::new("rocm-smi").arg("--version").output() {
477 Ok(output) if output.status.success() => Ok(true),
478 _ => Ok(false),
479 }
480 }
481 }
482 }
483 GpuBackend::Metal => {
484 #[cfg(target_os = "macos")]
485 {
486 Ok(true)
488 }
489 #[cfg(not(target_os = "macos"))]
490 {
491 Ok(false)
492 }
493 }
494 GpuBackend::OpenCL => {
495 match Command::new("clinfo").output() {
497 Ok(output) if output.status.success() => Ok(true),
498 _ => Ok(false),
499 }
500 }
501 GpuBackend::Wgpu => {
502 Ok(true)
504 }
505 GpuBackend::Cpu => Ok(true),
506 }
507}
508
509#[allow(dead_code)]
511pub fn get_device_info(backend: GpuBackend, device_id: usize) -> Result<GpuInfo, GpuError> {
512 let detection_result = detect_gpu_backends();
513
514 detection_result
515 .devices
516 .into_iter()
517 .filter(|d| d.backend == backend)
518 .nth(device_id)
519 .ok_or_else(|| {
520 GpuError::InvalidParameter(format!(
521 "Device {device_id} not found for backend {:?}",
522 backend
523 ))
524 })
525}
526
527#[allow(dead_code)]
529pub fn initialize_optimal_backend() -> Result<GpuBackend, GpuError> {
530 let detection_result = detect_gpu_backends();
531
532 let preference_order = [
534 GpuBackend::Cuda, GpuBackend::Rocm, GpuBackend::Metal, GpuBackend::OpenCL, GpuBackend::Wgpu, GpuBackend::Cpu, ];
541
542 for backend in preference_order.iter() {
543 if detection_result
544 .devices
545 .iter()
546 .any(|d: &GpuInfo| d.backend == *backend)
547 {
548 return Ok(*backend);
549 }
550 }
551
552 Ok(GpuBackend::Cpu)
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Probe helpers must either return a non-empty list of devices of their own
    /// backend or fail with `BackendNotAvailable`; they must not panic when the
    /// vendor tool is missing.
    fn assert_probe_outcome(result: Result<Vec<GpuInfo>, GpuError>, backend: GpuBackend) {
        match result {
            Ok(devices) => {
                assert!(!devices.is_empty());
                assert!(devices.iter().all(|d| d.backend == backend));
            }
            Err(GpuError::BackendNotAvailable(_)) => {}
            Err(other) => panic!("unexpected probe error: {other:?}"),
        }
    }

    #[test]
    fn test_gpu_info_creation() {
        let info = GpuInfo {
            backend: GpuBackend::Cuda,
            device_name: "NVIDIA GeForce RTX 3080".to_string(),
            memory_bytes: Some(10 * 1024 * 1024 * 1024),
            compute_capability: Some("8.6".to_string()),
            supports_tensors: true,
        };

        assert_eq!(info.backend, GpuBackend::Cuda);
        assert_eq!(info.device_name, "NVIDIA GeForce RTX 3080");
        assert_eq!(info.memory_bytes, Some(10 * 1024 * 1024 * 1024));
        assert_eq!(info.compute_capability, Some("8.6".to_string()));
        assert!(info.supports_tensors);
    }

    #[test]
    fn test_gpu_detection_result_with_cpu_fallback() {
        let result = detect_gpu_backends();

        // The CPU fallback is always appended after any detected GPUs.
        let last = result
            .devices
            .last()
            .expect("detection always reports a CPU device");
        assert_eq!(last.backend, GpuBackend::Cpu);
        assert_eq!(last.device_name, "CPU");

        // Detection never produces WebGPU devices, so it is never recommended.
        assert_ne!(result.recommended_backend, GpuBackend::Wgpu);
    }

    #[test]
    fn test_detection_skips_hardware_probes_in_test_builds() {
        // Hardware probing is compiled out under cfg(test), leaving only the CPU.
        let result = detect_gpu_backends();
        assert_eq!(result.devices.len(), 1);
        assert_eq!(result.recommended_backend, GpuBackend::Cpu);
    }

    #[test]
    fn test_check_backend_installation_cpu() {
        let result = check_backend_installation(GpuBackend::Cpu).expect("Operation failed");
        assert!(result);
    }

    #[test]
    fn test_check_backend_installation_wgpu() {
        let result = check_backend_installation(GpuBackend::Wgpu).expect("Operation failed");
        assert!(result);
    }

    #[test]
    fn test_check_backend_installation_metal() {
        let result = check_backend_installation(GpuBackend::Metal).expect("Operation failed");
        assert_eq!(result, cfg!(target_os = "macos"));
    }

    #[test]
    fn test_initialize_optimal_backend() {
        let backend = initialize_optimal_backend().expect("Operation failed");

        // Exhaustive on purpose: adding a backend variant should force a review here.
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_get_device_info_invalid_device() {
        let result = get_device_info(GpuBackend::Cpu, 100);

        match result {
            Err(GpuError::InvalidParameter(_)) => {}
            _ => panic!("Expected InvalidParameter error"),
        }
    }

    #[test]
    fn test_get_device_info_cpu() {
        let info = get_device_info(GpuBackend::Cpu, 0).expect("Operation failed");
        assert_eq!(info.backend, GpuBackend::Cpu);
        assert_eq!(info.device_name, "CPU");
        assert!(!info.supports_tensors);
    }

    #[cfg(not(target_os = "macos"))]
    #[test]
    fn test_detect_metal_devices_non_macos() {
        match detect_metal_devices() {
            Err(GpuError::BackendNotAvailable(_)) => {}
            _ => panic!("Expected BackendNotAvailable error"),
        }
    }

    #[test]
    fn test_gpu_info_clone() {
        let info = GpuInfo {
            backend: GpuBackend::Rocm,
            device_name: "AMD Radeon RX 6900 XT".to_string(),
            memory_bytes: Some(16 * 1024 * 1024 * 1024),
            compute_capability: Some("RDNA2".to_string()),
            supports_tensors: true,
        };

        let cloned = info.clone();
        assert_eq!(info.backend, cloned.backend);
        assert_eq!(info.device_name, cloned.device_name);
        assert_eq!(info.memory_bytes, cloned.memory_bytes);
        assert_eq!(info.compute_capability, cloned.compute_capability);
        assert_eq!(info.supports_tensors, cloned.supports_tensors);
    }

    #[test]
    fn test_gpu_detection_result_clone() {
        let devices = vec![
            GpuInfo {
                backend: GpuBackend::Cuda,
                device_name: "NVIDIA A100".to_string(),
                memory_bytes: Some(40 * 1024 * 1024 * 1024),
                compute_capability: Some("8.0".to_string()),
                supports_tensors: true,
            },
            GpuInfo {
                backend: GpuBackend::Cpu,
                device_name: "CPU".to_string(),
                memory_bytes: None,
                compute_capability: None,
                supports_tensors: false,
            },
        ];

        let result = GpuDetectionResult {
            devices,
            recommended_backend: GpuBackend::Cuda,
        };

        let cloned = result.clone();
        assert_eq!(result.devices.len(), cloned.devices.len());
        assert_eq!(result.recommended_backend, cloned.recommended_backend);
    }

    #[test]
    fn test_detect_cuda_devices_error_handling() {
        assert_probe_outcome(detect_cuda_devices(), GpuBackend::Cuda);
    }

    #[test]
    fn test_detect_rocm_devices_error_handling() {
        assert_probe_outcome(detect_rocm_devices(), GpuBackend::Rocm);
    }

    #[test]
    fn test_detect_opencl_devices_error_handling() {
        assert_probe_outcome(detect_opencl_devices(), GpuBackend::OpenCL);
    }

    #[test]
    fn test_backend_preference_order() {
        // Both entry points rank backends the same way, so they must agree.
        let recommended = detect_gpu_backends().recommended_backend;
        let optimal = initialize_optimal_backend().expect("Operation failed");
        assert_eq!(optimal, recommended);
    }
}