// scirs2_core/gpu/backends/mod.rs

//! GPU backend modules and runtime device detection.

use crate::gpu::{GpuBackend, GpuError};
use std::process::Command;

#[cfg(all(target_os = "macos", feature = "serialization"))]
use serde_json;

#[cfg(feature = "validation")]
use regex::Regex;

// Backend implementation modules, gated by their respective features.
#[cfg(feature = "cuda")]
pub mod cuda;

#[cfg(feature = "opencl")]
pub mod opencl;

#[cfg(feature = "wgpu_backend")]
pub mod wgpu;

#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod metal;

#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod metal_mps;

#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
pub mod metal_mpsgraph;

// Re-exported context types and helpers from the backend modules.
#[cfg(feature = "cuda")]
pub use cuda::{get_optimizer_kernels, CudaContext, CudaStream};

#[cfg(feature = "opencl")]
pub use opencl::OpenCLContext;

#[cfg(feature = "wgpu_backend")]
pub use wgpu::WebGPUContext;

#[cfg(all(feature = "metal", target_os = "macos"))]
pub use metal::{MetalBufferOptions, MetalContext, MetalStorageMode};

#[cfg(all(feature = "metal", target_os = "macos"))]
pub use metal_mps::{MPSContext, MPSDataType, MPSOperations};

#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
pub use metal_mpsgraph::MPSGraphContext;

/// Information about a single detected GPU device.
#[derive(Debug, Clone)]
pub struct GpuInfo {
    /// Backend that can drive this device.
    pub backend: GpuBackend,
    /// Human-readable device name.
    pub device_name: String,
    /// Total device memory in bytes, if it could be determined.
    pub memory_bytes: Option<u64>,
    /// Backend-specific capability string (e.g. CUDA compute capability or Metal family).
    pub compute_capability: Option<String>,
    /// Whether the device exposes dedicated tensor/matrix acceleration.
    pub supports_tensors: bool,
}

/// Result of scanning the system for available GPU backends.
#[derive(Debug, Clone)]
pub struct GpuDetectionResult {
    /// All detected devices, including the CPU fallback.
    pub devices: Vec<GpuInfo>,
    /// Backend recommended based on the detected devices.
    pub recommended_backend: GpuBackend,
}

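/// Detects available GPU backends by probing vendor command-line tools
/// (`nvidia-smi`, `rocm-smi`, `system_profiler`, `clinfo`) and always appends a CPU
/// fallback entry, so the returned device list is never empty.
///
/// Illustrative usage sketch (the `scirs2_core::gpu::backends` import path is an
/// assumption about the crate layout, not taken from this file):
///
/// ```ignore
/// use scirs2_core::gpu::backends::detect_gpu_backends;
///
/// let result = detect_gpu_backends();
/// for device in &result.devices {
///     println!(
///         "{:?}: {} (memory: {:?}, tensors: {})",
///         device.backend, device.device_name, device.memory_bytes, device.supports_tensors
///     );
/// }
/// println!("recommended backend: {:?}", result.recommended_backend);
/// ```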
#[allow(dead_code)]
pub fn detect_gpu_backends() -> GpuDetectionResult {
    let mut devices = Vec::new();

    // Only probe the system outside of unit tests so tests stay hermetic.
    #[cfg(not(test))]
    {
        if let Ok(cuda_devices) = detect_cuda_devices() {
            devices.extend(cuda_devices);
        }

        if let Ok(rocm_devices) = detect_rocm_devices() {
            devices.extend(rocm_devices);
        }

        #[cfg(target_os = "macos")]
        if let Ok(metal_devices) = detect_metal_devices() {
            devices.extend(metal_devices);
        }

        if let Ok(opencl_devices) = detect_opencl_devices() {
            devices.extend(opencl_devices);
        }
    }

    // Prefer CUDA, then ROCm, Metal, and OpenCL; otherwise fall back to the CPU.
    let recommended_backend = if devices
        .iter()
        .any(|d: &GpuInfo| d.backend == GpuBackend::Cuda)
    {
        GpuBackend::Cuda
    } else if devices
        .iter()
        .any(|d: &GpuInfo| d.backend == GpuBackend::Rocm)
    {
        GpuBackend::Rocm
    } else if devices
        .iter()
        .any(|d: &GpuInfo| d.backend == GpuBackend::Metal)
    {
        GpuBackend::Metal
    } else if devices
        .iter()
        .any(|d: &GpuInfo| d.backend == GpuBackend::OpenCL)
    {
        GpuBackend::OpenCL
    } else {
        GpuBackend::Cpu
    };

    // Always include the CPU as a universal fallback device.
    devices.push(GpuInfo {
        backend: GpuBackend::Cpu,
        device_name: "CPU".to_string(),
        memory_bytes: None,
        compute_capability: None,
        supports_tensors: false,
    });

    GpuDetectionResult {
        devices,
        recommended_backend,
    }
}

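// Sketch of the input this parser expects (an assumption, since `rocm-smi --csv` output
// varies between ROCm releases): one header row followed by one row per device, with the
// product name in the second column and the VRAM size in MB in the third, e.g.
// `card0,"AMD Radeon RX 6900 XT","16368 MB"`.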
#[allow(dead_code)]
fn detect_rocm_devices() -> Result<Vec<GpuInfo>, GpuError> {
    let mut devices = Vec::new();

    match Command::new("rocm-smi")
        .arg("--showproductname")
        .arg("--showmeminfo")
        .arg("vram")
        .arg("--csv")
        .output()
    {
        Ok(output) if output.status.success() => {
            let output_str = String::from_utf8_lossy(&output.stdout);

            // Skip the CSV header row.
            for line in output_str.lines().skip(1) {
                if line.trim().is_empty() {
                    continue;
                }

                let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                if parts.len() >= 3 {
                    let device_name = parts[1].trim_matches('"').to_string();
                    let memory_str = parts[2].trim_matches('"');

                    // VRAM is reported in MB; convert to bytes.
                    let memory_bytes = memory_str
                        .split_whitespace()
                        .next()
                        .and_then(|s| s.parse::<u64>().ok())
                        .unwrap_or(0)
                        * 1024
                        * 1024;

                    devices.push(GpuInfo {
                        backend: GpuBackend::Rocm,
                        device_name,
                        memory_bytes: Some(memory_bytes),
                        compute_capability: Some("RDNA/CDNA".to_string()),
                        supports_tensors: true,
                    });
                }
            }
        }
        _ => {
            // rocm-smi is not installed or failed; fall through to the empty-device check.
        }
    }

    if devices.is_empty() {
        Err(GpuError::BackendNotAvailable("ROCm".to_string()))
    } else {
        Ok(devices)
    }
}

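// With `--format=csv,noheader,nounits`, `nvidia-smi --query-gpu=name,memory.total,compute_cap`
// prints one device per line such as `NVIDIA GeForce RTX 3080, 10240, 8.6`
// (name, total memory in MiB, compute capability), which is what the parser below assumes.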
#[allow(dead_code)]
fn detect_cuda_devices() -> Result<Vec<GpuInfo>, GpuError> {
    let mut devices = Vec::new();

    match Command::new("nvidia-smi")
        .arg("--query-gpu=name,memory.total,compute_cap")
        .arg("--format=csv,noheader,nounits")
        .output()
    {
        Ok(output) if output.status.success() => {
            let output_str = String::from_utf8_lossy(&output.stdout);

            for line in output_str.lines() {
                if line.trim().is_empty() {
                    continue;
                }

                let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
                if parts.len() >= 3 {
                    let device_name = parts[0].to_string();
                    // memory.total is reported in MiB; convert to bytes.
                    let memory_bytes = parts[1].parse::<u64>().unwrap_or(0) * 1024 * 1024;
                    let compute_capability = parts[2].to_string();

                    // Tensor cores are available from compute capability 7.x (Volta) onward.
                    let supports_tensors =
                        if let Some(major_str) = compute_capability.split('.').next() {
                            major_str.parse::<u32>().unwrap_or(0) >= 7
                        } else {
                            false
                        };

                    devices.push(GpuInfo {
                        backend: GpuBackend::Cuda,
                        device_name,
                        memory_bytes: Some(memory_bytes),
                        compute_capability: Some(compute_capability),
                        supports_tensors,
                    });
                }
            }
        }
        _ => {
            // nvidia-smi is not installed or failed; fall through to the empty-device check.
        }
    }

    if devices.is_empty() {
        Err(GpuError::BackendNotAvailable("CUDA".to_string()))
    } else {
        Ok(devices)
    }
}

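// `detect_metal_devices` shells out to `system_profiler SPDisplaysDataType -json` and, when
// the `serialization` feature is enabled, reads the `sppci_model`, `vram`/`vram_pcie`, and
// `sppci_metal_family` fields of each display entry. If that yields nothing, it falls back
// to the Metal API (with the `metal` feature) or to a generic placeholder entry.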
#[cfg(target_os = "macos")]
#[allow(dead_code)]
fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
    let mut devices = Vec::new();

    match Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-json")
        .output()
    {
        Ok(output) if output.status.success() => {
            #[cfg(feature = "serialization")]
            {
                use std::str::FromStr;
                let output_str = String::from_utf8_lossy(&output.stdout);

                if let Ok(json_value) = serde_json::Value::from_str(&output_str) {
                    if let Some(displays) = json_value
                        .get("SPDisplaysDataType")
                        .and_then(|v: &serde_json::Value| v.as_array())
                    {
                        #[cfg(feature = "validation")]
                        let vram_regex = Regex::new(r"(\d+)\s*(GB|MB)").ok();

                        for display in displays {
                            if let Some(model) = display
                                .get("sppci_model")
                                .and_then(|v: &serde_json::Value| v.as_str())
                            {
                                let mut gpu_info = GpuInfo {
                                    backend: GpuBackend::Metal,
                                    device_name: model.to_string(),
                                    memory_bytes: None,
                                    compute_capability: None,
                                    supports_tensors: true,
                                };

                                if let Some(vram_str) = display
                                    .get("vram_pcie")
                                    .and_then(|v: &serde_json::Value| v.as_str())
                                    .or_else(|| {
                                        display
                                            .get("vram")
                                            .and_then(|v: &serde_json::Value| v.as_str())
                                    })
                                {
                                    // Parse strings such as "8 GB" or "1536 MB" into bytes.
                                    #[cfg(feature = "validation")]
                                    if let Some(captures) =
                                        vram_regex.as_ref().and_then(|re| re.captures(vram_str))
                                    {
                                        if let (Some(value), Some(unit)) =
                                            (captures.get(1), captures.get(2))
                                        {
                                            if let Ok(num) = u64::from_str(value.as_str()) {
                                                gpu_info.memory_bytes = Some(match unit.as_str() {
                                                    "GB" => num * 1024 * 1024 * 1024,
                                                    "MB" => num * 1024 * 1024,
                                                    _ => 0,
                                                });
                                            }
                                        }
                                    }
                                }

                                if let Some(metal_family) = display
                                    .get("sppci_metal_family")
                                    .and_then(|v: &serde_json::Value| v.as_str())
                                {
                                    gpu_info.compute_capability = Some(metal_family.to_string());
                                }

                                devices.push(gpu_info);
                            }
                        }
                    }
                }
            }

            // If system_profiler parsing found nothing, fall back to the Metal API or a
            // generic placeholder entry.
            if devices.is_empty() {
                #[cfg(feature = "metal")]
                {
                    use metal::Device;
                    if let Some(device) = Device::system_default() {
                        let name = device.name().to_string();
                        let mut gpu_info = GpuInfo {
                            backend: GpuBackend::Metal,
                            device_name: name.clone(),
                            memory_bytes: None,
                            compute_capability: None,
                            supports_tensors: true,
                        };

                        gpu_info.compute_capability = Some("Metal GPU".to_string());

                        devices.push(gpu_info);
                    }
                }

                #[cfg(not(feature = "metal"))]
                {
                    devices.push(GpuInfo {
                        backend: GpuBackend::Metal,
                        device_name: "Metal GPU".to_string(),
                        memory_bytes: None,
                        compute_capability: None,
                        supports_tensors: true,
                    });
                }
            }
        }
        _ => {
            // system_profiler failed; query the Metal API directly if available.
            #[cfg(feature = "metal")]
            {
                use metal::Device;
                if let Some(device) = Device::system_default() {
                    devices.push(GpuInfo {
                        backend: GpuBackend::Metal,
                        device_name: device.name().to_string(),
                        memory_bytes: None,
                        compute_capability: None,
                        supports_tensors: true,
                    });
                } else {
                    return Err(GpuError::BackendNotAvailable("Metal".to_string()));
                }
            }

            #[cfg(not(feature = "metal"))]
            {
                return Err(GpuError::BackendNotAvailable("Metal".to_string()));
            }
        }
    }

    if devices.is_empty() {
        Err(GpuError::BackendNotAvailable("Metal".to_string()))
    } else {
        Ok(devices)
    }
}

#[cfg(not(target_os = "macos"))]
#[allow(dead_code)]
fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
    Err(GpuError::BackendNotAvailable(
        "Metal (not macOS)".to_string(),
    ))
}

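// `detect_opencl_devices` only checks whether `clinfo --list` reports any line starting with
// "Platform" or "Device" (for example `Platform #0: NVIDIA CUDA`; an illustrative line, not
// captured output); it records a single generic OpenCL entry rather than enumerating devices.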
#[allow(dead_code)]
fn detect_opencl_devices() -> Result<Vec<GpuInfo>, GpuError> {
    let mut devices = Vec::new();

    match Command::new("clinfo").arg("--list").output() {
        Ok(output) if output.status.success() => {
            let output_str = String::from_utf8_lossy(&output.stdout);

            for line in output_str.lines() {
                if line.trim().starts_with("Platform") || line.trim().starts_with("Device") {
                    // At least one OpenCL platform/device exists; record a generic entry.
                    devices.push(GpuInfo {
                        backend: GpuBackend::OpenCL,
                        device_name: "OpenCL Device".to_string(),
                        memory_bytes: None,
                        compute_capability: None,
                        supports_tensors: false,
                    });
                    break;
                }
            }
        }
        _ => {
            return Err(GpuError::BackendNotAvailable("OpenCL".to_string()));
        }
    }

    if devices.is_empty() {
        Err(GpuError::BackendNotAvailable("OpenCL".to_string()))
    } else {
        Ok(devices)
    }
}

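/// Checks whether the toolchain for the given backend appears to be installed
/// (`nvcc` for CUDA, `hipcc` or `rocm-smi` for ROCm, `clinfo` for OpenCL). Metal is
/// reported as available only on macOS; WGPU and CPU are always available.
///
/// Illustrative usage sketch (import path assumed, not taken from this file):
///
/// ```ignore
/// use scirs2_core::gpu::{backends::check_backend_installation, GpuBackend};
///
/// if check_backend_installation(GpuBackend::Cuda).unwrap_or(false) {
///     println!("CUDA toolchain found");
/// }
/// ```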
#[allow(dead_code)]
pub fn check_backend_installation(backend: GpuBackend) -> Result<bool, GpuError> {
    match backend {
        GpuBackend::Cuda => {
            // Look for the CUDA compiler.
            match Command::new("nvcc").arg("--version").output() {
                Ok(output) if output.status.success() => Ok(true),
                _ => Ok(false),
            }
        }
        GpuBackend::Rocm => {
            // Look for the HIP compiler first, then fall back to rocm-smi.
            match Command::new("hipcc").arg("--version").output() {
                Ok(output) if output.status.success() => Ok(true),
                _ => match Command::new("rocm-smi").arg("--version").output() {
                    Ok(output) if output.status.success() => Ok(true),
                    _ => Ok(false),
                },
            }
        }
        GpuBackend::Metal => {
            // Metal ships with macOS and is never available elsewhere.
            #[cfg(target_os = "macos")]
            {
                Ok(true)
            }
            #[cfg(not(target_os = "macos"))]
            {
                Ok(false)
            }
        }
        GpuBackend::OpenCL => {
            match Command::new("clinfo").output() {
                Ok(output) if output.status.success() => Ok(true),
                _ => Ok(false),
            }
        }
        GpuBackend::Wgpu => {
            // WGPU is compiled into the crate and needs no external toolchain.
            Ok(true)
        }
        GpuBackend::Cpu => Ok(true),
    }
}

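/// Returns information about the `device_id`-th detected device of the given backend,
/// or `GpuError::InvalidParameter` if no such device exists.
///
/// Illustrative usage sketch (import path assumed, not taken from this file):
///
/// ```ignore
/// use scirs2_core::gpu::{backends::get_device_info, GpuBackend};
///
/// // The CPU fallback is always present as device 0.
/// let cpu = get_device_info(GpuBackend::Cpu, 0).expect("CPU device should always exist");
/// assert_eq!(cpu.device_name, "CPU");
/// ```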
#[allow(dead_code)]
pub fn get_device_info(backend: GpuBackend, device_id: usize) -> Result<GpuInfo, GpuError> {
    let detection_result = detect_gpu_backends();

    detection_result
        .devices
        .into_iter()
        .filter(|d| d.backend == backend)
        .nth(device_id)
        .ok_or_else(|| {
            GpuError::InvalidParameter(format!(
                "Device {device_id} not found for backend {backend:?}"
            ))
        })
}

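/// Selects the best available backend in the fixed order
/// CUDA > ROCm > Metal > OpenCL > WGPU > CPU, based on [`detect_gpu_backends`].
///
/// Illustrative usage sketch (import path assumed, not taken from this file):
///
/// ```ignore
/// use scirs2_core::gpu::backends::initialize_optimal_backend;
///
/// let backend = initialize_optimal_backend().expect("backend selection never fails");
/// println!("using backend: {:?}", backend);
/// ```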
#[allow(dead_code)]
pub fn initialize_optimal_backend() -> Result<GpuBackend, GpuError> {
    let detection_result = detect_gpu_backends();

    // Preference order: most capable backends first, CPU last.
    let preference_order = [
        GpuBackend::Cuda,
        GpuBackend::Rocm,
        GpuBackend::Metal,
        GpuBackend::OpenCL,
        GpuBackend::Wgpu,
        GpuBackend::Cpu,
    ];

    for backend in preference_order.iter() {
        if detection_result
            .devices
            .iter()
            .any(|d: &GpuInfo| d.backend == *backend)
        {
            return Ok(*backend);
        }
    }

    Ok(GpuBackend::Cpu)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_info_creation() {
        let info = GpuInfo {
            backend: GpuBackend::Cuda,
            device_name: "NVIDIA GeForce RTX 3080".to_string(),
            memory_bytes: Some(10 * 1024 * 1024 * 1024),
            compute_capability: Some("8.6".to_string()),
            supports_tensors: true,
        };

        assert_eq!(info.backend, GpuBackend::Cuda);
        assert_eq!(info.device_name, "NVIDIA GeForce RTX 3080");
        assert_eq!(info.memory_bytes, Some(10 * 1024 * 1024 * 1024));
        assert_eq!(info.compute_capability, Some("8.6".to_string()));
        assert!(info.supports_tensors);
    }

    #[test]
    fn test_gpu_detection_result_with_cpu_fallback() {
        let result = detect_gpu_backends();

        assert!(!result.devices.is_empty());
        assert!(result
            .devices
            .iter()
            .any(|d: &GpuInfo| d.backend == GpuBackend::Cpu));

        match result.recommended_backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
            _ => panic!("Unexpected recommended backend"),
        }
    }

    #[test]
    fn test_check_backend_installation_cpu() {
        let result = check_backend_installation(GpuBackend::Cpu).expect("Operation failed");
        assert!(result);
    }

    #[test]
    fn test_check_backend_installation_wgpu() {
        let result = check_backend_installation(GpuBackend::Wgpu).expect("Operation failed");
        assert!(result);
    }

    #[test]
    fn test_check_backend_installation_metal() {
        let result = check_backend_installation(GpuBackend::Metal).expect("Operation failed");
        #[cfg(target_os = "macos")]
        assert!(result);
        #[cfg(not(target_os = "macos"))]
        assert!(!result);
    }

    #[test]
    fn test_initialize_optimal_backend() {
        let backend = initialize_optimal_backend().expect("Operation failed");

        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_get_device_info_invalid_device() {
        let result = get_device_info(GpuBackend::Cpu, 100);

        assert!(result.is_err());
        match result {
            Err(GpuError::InvalidParameter(_)) => {}
            _ => panic!("Expected InvalidParameter error"),
        }
    }

    #[test]
    fn test_get_device_info_cpu() {
        let result = get_device_info(GpuBackend::Cpu, 0);

        assert!(result.is_ok());
        let info = result.expect("Operation failed");
        assert_eq!(info.backend, GpuBackend::Cpu);
        assert_eq!(info.device_name, "CPU");
        assert!(!info.supports_tensors);
    }

    #[test]
    fn test_detect_metal_devices_non_macos() {
        #[cfg(not(target_os = "macos"))]
        {
            let result = detect_metal_devices();
            assert!(result.is_err());
            match result {
                Err(GpuError::BackendNotAvailable(_)) => {}
                _ => panic!("Expected BackendNotAvailable error"),
            }
        }
    }

    #[test]
    fn test_gpu_info_clone() {
        let info = GpuInfo {
            backend: GpuBackend::Rocm,
            device_name: "AMD Radeon RX 6900 XT".to_string(),
            memory_bytes: Some(16 * 1024 * 1024 * 1024),
            compute_capability: Some("RDNA2".to_string()),
            supports_tensors: true,
        };

        let cloned = info.clone();
        assert_eq!(info.backend, cloned.backend);
        assert_eq!(info.device_name, cloned.device_name);
        assert_eq!(info.memory_bytes, cloned.memory_bytes);
        assert_eq!(info.compute_capability, cloned.compute_capability);
        assert_eq!(info.supports_tensors, cloned.supports_tensors);
    }

    #[test]
    fn test_gpu_detection_result_clone() {
        let devices = vec![
            GpuInfo {
                backend: GpuBackend::Cuda,
                device_name: "NVIDIA A100".to_string(),
                memory_bytes: Some(40 * 1024 * 1024 * 1024),
                compute_capability: Some("8.0".to_string()),
                supports_tensors: true,
            },
            GpuInfo {
                backend: GpuBackend::Cpu,
                device_name: "CPU".to_string(),
                memory_bytes: None,
                compute_capability: None,
                supports_tensors: false,
            },
        ];

        let result = GpuDetectionResult {
            devices: devices.clone(),
            recommended_backend: GpuBackend::Cuda,
        };

        let cloned = result.clone();
        assert_eq!(result.devices.len(), cloned.devices.len());
        assert_eq!(result.recommended_backend, cloned.recommended_backend);
    }

    #[test]
    fn test_detect_cuda_devices_error_handling() {
        // Must not panic whether or not nvidia-smi is installed.
        let _ = detect_cuda_devices();
    }

    #[test]
    fn test_detect_rocm_devices_error_handling() {
        // Must not panic whether or not rocm-smi is installed.
        let _ = detect_rocm_devices();
    }

    #[test]
    fn test_detect_opencl_devices_error_handling() {
        // Must not panic whether or not clinfo is installed.
        let _ = detect_opencl_devices();
    }

    #[test]
    fn test_backend_preference_order() {
        let result = detect_gpu_backends();

        // If any CUDA device was detected, it must be selected as the optimal backend.
        if result
            .devices
            .iter()
            .any(|d: &GpuInfo| d.backend == GpuBackend::Cuda)
        {
            let optimal = initialize_optimal_backend().expect("Operation failed");
            if result
                .devices
                .iter()
                .filter(|d| d.backend == GpuBackend::Cuda)
                .count()
                > 0
            {
                assert_eq!(optimal, GpuBackend::Cuda);
            }
        }
    }
}