// scirs2_core/gpu/backends/mod.rs

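//! GPU backend implementations and detection utilities.
//!
//! This module contains backend-specific implementations for various GPU platforms
//! and utilities for detecting which GPU backends are available on the host.
//! Detection shells out to vendor command-line tools (`nvidia-smi`, `rocm-smi`,
//! `system_profiler`, `clinfo`) and always includes a CPU fallback entry.
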
6use crate::gpu::{GpuBackend, GpuError};
7use std::process::Command;
8
9#[cfg(target_os = "macos")]
10use serde_json;
11
12#[cfg(feature = "validation")]
13use regex::Regex;
14
15// Backend implementation modules
16#[cfg(feature = "cuda")]
17pub mod cuda;
18
19#[cfg(feature = "opencl")]
20pub mod opencl;
21
22#[cfg(feature = "wgpu_backend")]
23pub mod wgpu;
24
25#[cfg(all(feature = "metal", target_os = "macos"))]
26pub mod metal;
27
28#[cfg(all(feature = "metal", target_os = "macos"))]
29pub mod metal_mps;
30
31#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
32pub mod metal_mpsgraph;
33
34// Re-export backend implementations
35#[cfg(feature = "cuda")]
36pub use cuda::{get_optimizer_kernels, CudaContext, CudaStream};
37
38#[cfg(feature = "opencl")]
39pub use opencl::OpenCLContext;
40
41#[cfg(feature = "wgpu_backend")]
42pub use wgpu::WebGPUContext;
43
44#[cfg(all(feature = "metal", target_os = "macos"))]
45pub use metal::{MetalBufferOptions, MetalContext, MetalStorageMode};
46
47#[cfg(all(feature = "metal", target_os = "macos"))]
48pub use metal_mps::{MPSContext, MPSDataType, MPSOperations};
49
50#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
51pub use metal_mpsgraph::MPSGraphContext;
52
53/// Information about available GPU hardware
54#[derive(Debug, Clone)]
55pub struct GpuInfo {
56    /// The GPU backend type
57    pub backend: GpuBackend,
58    /// Device name
59    pub device_name: String,
60    /// Available memory in bytes
61    pub memory_bytes: Option<u64>,
62    /// Compute capability or equivalent
63    pub compute_capability: Option<String>,
64    /// Whether the device supports tensor operations
65    pub supports_tensors: bool,
66}
67
68/// Detection results for all available GPU backends
69#[derive(Debug, Clone)]
70pub struct GpuDetectionResult {
71    /// Available GPU devices
72    pub devices: Vec<GpuInfo>,
73    /// Recommended backend for scientific computing
74    pub recommended_backend: GpuBackend,
75}
76
77/// Detect available GPU backends and devices
78#[allow(dead_code)]
79pub fn detect_gpu_backends() -> GpuDetectionResult {
80    let mut devices = Vec::new();
81
82    // Skip GPU detection in test environment to avoid segfaults from external commands
83    #[cfg(not(test))]
84    {
85        // Detect CUDA devices
86        if let Ok(cuda_devices) = detect_cuda_devices() {
87            devices.extend(cuda_devices);
88        }
89
90        // Detect ROCm devices
91        if let Ok(rocm_devices) = detect_rocm_devices() {
92            devices.extend(rocm_devices);
93        }
94
95        // Detect Metal devices (macOS)
96        #[cfg(target_os = "macos")]
97        if let Ok(metal_devices) = detect_metal_devices() {
98            devices.extend(metal_devices);
99        }
100
101        // Detect OpenCL devices
102        if let Ok(opencl_devices) = detect_opencl_devices() {
103            devices.extend(opencl_devices);
104        }
105    }
106
107    // Determine recommended backend
108    let recommended_backend = if devices
109        .iter()
110        .any(|d: &GpuInfo| d.backend == GpuBackend::Cuda)
111    {
112        GpuBackend::Cuda
113    } else if devices
114        .iter()
115        .any(|d: &GpuInfo| d.backend == GpuBackend::Rocm)
116    {
117        GpuBackend::Rocm
118    } else if devices
119        .iter()
120        .any(|d: &GpuInfo| d.backend == GpuBackend::Metal)
121    {
122        GpuBackend::Metal
123    } else if devices
124        .iter()
125        .any(|d: &GpuInfo| d.backend == GpuBackend::OpenCL)
126    {
127        GpuBackend::OpenCL
128    } else {
129        GpuBackend::Cpu
130    };
131
132    // Always add CPU fallback
133    devices.push(GpuInfo {
134        backend: GpuBackend::Cpu,
135        device_name: "CPU".to_string(),
136        memory_bytes: None,
137        compute_capability: None,
138        supports_tensors: false,
139    });
140
141    GpuDetectionResult {
142        devices,
143        recommended_backend,
144    }
145}
146
147/// Detect ROCm devices using rocm-smi
148#[allow(dead_code)]
149fn detect_rocm_devices() -> Result<Vec<GpuInfo>, GpuError> {
150    let mut devices = Vec::new();
151
152    // Try to run rocm-smi to detect ROCm devices
153    match Command::new("rocm-smi")
154        .arg("--showproductname")
155        .arg("--showmeminfo")
156        .arg("vram")
157        .arg("--csv")
158        .output()
159    {
160        Ok(output) if output.status.success() => {
161            let output_str = String::from_utf8_lossy(&output.stdout);
162
163            for line in output_str.lines().skip(1) {
164                // Skip header line
165                if line.trim().is_empty() {
166                    continue;
167                }
168
169                let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
170                if parts.len() >= 3 {
171                    let device_name = parts[1].trim_matches('"').to_string();
172                    let memory_str = parts[2].trim_matches('"');
173
174                    // Parse memory (format might be like "16368 MB")
175                    let memory_mb = memory_str
176                        .split_whitespace()
177                        .next()
178                        .and_then(|s| s.parse::<u64>().ok())
179                        .unwrap_or(0)
180                        * 1024
181                        * 1024; // Convert MB to bytes
182
183                    devices.push(GpuInfo {
184                        backend: GpuBackend::Rocm,
185                        device_name,
186                        memory_bytes: Some(memory_mb),
187                        compute_capability: Some("RDNA/CDNA".to_string()),
188                        supports_tensors: true, // Modern AMD GPUs support matrix operations
189                    });
190                }
191            }
192        }
193        _ => {
194            // rocm-smi not available or failed
195            // In a real implementation, we could try other methods like:
196            // - Direct HIP runtime API calls
197            // - /sys/class/drm/cardX/ on Linux
198            // - rocminfo command
199        }
200    }
201
202    if devices.is_empty() {
203        Err(GpuError::BackendNotAvailable("ROCm".to_string()))
204    } else {
205        Ok(devices)
206    }
207}
208
209/// Detect CUDA devices using nvidia-ml-py or nvidia-smi
210#[allow(dead_code)]
211fn detect_cuda_devices() -> Result<Vec<GpuInfo>, GpuError> {
212    let mut devices = Vec::new();
213
214    // Try to run nvidia-smi to detect CUDA devices
215    match Command::new("nvidia-smi")
216        .arg("--query-gpu=name,memory.total,compute_cap")
217        .arg("--format=csv,noheader,nounits")
218        .output()
219    {
220        Ok(output) if output.status.success() => {
221            let output_str = String::from_utf8_lossy(&output.stdout);
222
223            for line in output_str.lines() {
224                if line.trim().is_empty() {
225                    continue;
226                }
227
228                let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
229                if parts.len() >= 3 {
230                    let device_name = parts[0].to_string();
231                    let memory_mb = parts[1].parse::<u64>().unwrap_or(0) * 1024 * 1024; // Convert MB to bytes
232                    let compute_capability = parts[2].to_string();
233
234                    // Parse compute capability to determine tensor core support
235                    let supports_tensors =
236                        if let Some(major_str) = compute_capability.split('.').next() {
237                            major_str.parse::<u32>().unwrap_or(0) >= 7 // Tensor cores available on Volta+ (7.0+)
238                        } else {
239                            false
240                        };
241
242                    devices.push(GpuInfo {
243                        backend: GpuBackend::Cuda,
244                        device_name,
245                        memory_bytes: Some(memory_mb),
246                        compute_capability: Some(compute_capability),
247                        supports_tensors,
248                    });
249                }
250            }
251        }
252        _ => {
253            // nvidia-smi not available or failed
254            // In a real implementation, we could try other methods like:
255            // - Direct CUDA runtime API calls
256            // - nvidia-ml-py if available
257            // - /proc/driver/nvidia/gpus/ on Linux
258        }
259    }
260
261    if devices.is_empty() {
262        Err(GpuError::BackendNotAvailable("CUDA".to_string()))
263    } else {
264        Ok(devices)
265    }
266}
267
/// Detect Metal devices (macOS only).
///
/// Strategy, in order:
/// 1. Run `system_profiler SPDisplaysDataType -json`. Every entry with an
///    `sppci_model` becomes a [`GpuInfo`] (see [`metal_devices_from_profiler_json`]).
/// 2. If the profiler ran but nothing usable was parsed, fall back to the Metal
///    API's system default device (`metal` feature). Without that feature, a
///    generic `"Metal GPU"` placeholder is used instead.
/// 3. If the profiler could not be run, only the Metal API fallback is tried.
///
/// VRAM strings such as `"8 GB"` / `"1536 MB"` are parsed without needing the
/// optional `validation` feature.
///
/// # Errors
///
/// Returns `GpuError::BackendNotAvailable("Metal")` when no device is found.
#[cfg(target_os = "macos")]
#[allow(dead_code)]
fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
    use std::str::FromStr;

    let profiler = Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-json")
        .output();

    let mut devices: Vec<GpuInfo> = Vec::new();
    match profiler {
        Ok(output) if output.status.success() => {
            let text = String::from_utf8_lossy(&output.stdout);
            if let Ok(json) = serde_json::Value::from_str(&text) {
                devices = metal_devices_from_profiler_json(&json);
            }

            if devices.is_empty() {
                // JSON parsing failed or listed no GPU model: ask Metal directly.
                #[cfg(feature = "metal")]
                devices.extend(metal_system_default_device());

                // Fallback if Metal crate not available but we're on macOS.
                #[cfg(not(feature = "metal"))]
                devices.push(GpuInfo {
                    backend: GpuBackend::Metal,
                    device_name: "Metal GPU".to_string(),
                    memory_bytes: None,
                    compute_capability: None,
                    supports_tensors: true,
                });
            }
        }
        _ => {
            // system_profiler failed; only the Metal API can still find a device.
            #[cfg(feature = "metal")]
            devices.extend(metal_system_default_device());
        }
    }

    if devices.is_empty() {
        Err(GpuError::BackendNotAvailable("Metal".to_string()))
    } else {
        Ok(devices)
    }
}

/// Build one [`GpuInfo`] per `SPDisplaysDataType` entry that reports an `sppci_model`.
///
/// VRAM and Metal-family keys vary between macOS releases and GPU types, so
/// several candidate keys are tried in order. The first string value found wins,
/// and the keys used by earlier versions of this code are tried first.
#[cfg(target_os = "macos")]
fn metal_devices_from_profiler_json(json: &serde_json::Value) -> Vec<GpuInfo> {
    const VRAM_KEYS: &[&str] = &[
        "vram_pcie",
        "vram",
        "spdisplays_vram",
        "spdisplays_vram_shared",
    ];
    const FAMILY_KEYS: &[&str] = &["sppci_metal_family", "spdisplays_mtlgpufamilysupport"];

    let displays = match json.get("SPDisplaysDataType").and_then(|v| v.as_array()) {
        Some(displays) => displays,
        None => return Vec::new(),
    };

    displays
        .iter()
        .filter_map(|display| {
            let model = display.get("sppci_model").and_then(|v| v.as_str())?;
            // First of `keys` whose value in this entry is a JSON string.
            let first_str = |keys: &[&str]| {
                keys.iter()
                    .find_map(|&key| display.get(key).and_then(|v| v.as_str()))
            };
            Some(GpuInfo {
                backend: GpuBackend::Metal,
                device_name: model.to_string(),
                memory_bytes: first_str(VRAM_KEYS).and_then(parse_vram_bytes),
                compute_capability: first_str(FAMILY_KEYS).map(str::to_string),
                supports_tensors: true,
            })
        })
        .collect()
}

/// Parse a `system_profiler` VRAM string such as `"8 GB"`, `"1536 MB"` or `"8GB"` into bytes.
///
/// Returns `None` when there is no leading number, the unit is neither GB nor MB,
/// or the result would overflow `u64`.
#[cfg(target_os = "macos")]
fn parse_vram_bytes(text: &str) -> Option<u64> {
    let text = text.trim();
    let digits_end = text
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(text.len());
    let amount: u64 = text[..digits_end].parse().ok()?;
    let multiplier: u64 = match text[digits_end..].trim_start() {
        unit if unit.starts_with("GB") => 1024 * 1024 * 1024,
        unit if unit.starts_with("MB") => 1024 * 1024,
        _ => return None,
    };
    amount.checked_mul(multiplier)
}

/// Describe the system default Metal device via the `metal` crate, if there is one.
///
/// Memory and GPU family are left unknown: MTLGPUFamily is not exposed in the
/// metal crate version this code targets.
#[cfg(all(target_os = "macos", feature = "metal"))]
fn metal_system_default_device() -> Option<GpuInfo> {
    use metal::Device;
    Device::system_default().map(|device| GpuInfo {
        backend: GpuBackend::Metal,
        device_name: device.name().to_string(),
        memory_bytes: None,
        compute_capability: None,
        supports_tensors: true,
    })
}
412
413/// Detect Metal devices (non-macOS - not available)
414#[cfg(not(target_os = "macos"))]
415#[allow(dead_code)]
416fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
417    Err(GpuError::BackendNotAvailable(
418        "Metal (not macOS)".to_string(),
419    ))
420}
421
422/// Detect OpenCL devices
423#[allow(dead_code)]
424fn detect_opencl_devices() -> Result<Vec<GpuInfo>, GpuError> {
425    let mut devices = Vec::new();
426
427    // Try to detect OpenCL devices using clinfo
428    match Command::new("clinfo").arg("--list").output() {
429        Ok(output) if output.status.success() => {
430            let output_str = String::from_utf8_lossy(&output.stdout);
431
432            for line in output_str.lines() {
433                if line.trim().starts_with("Platform") || line.trim().starts_with("Device") {
434                    // In a real implementation, we would parse clinfo output properly
435                    // For now, just add a generic OpenCL device
436                    devices.push(GpuInfo {
437                        backend: GpuBackend::OpenCL,
438                        device_name: "OpenCL Device".to_string(),
439                        memory_bytes: None,
440                        compute_capability: None,
441                        supports_tensors: false,
442                    });
443                    break; // Just add one for demo
444                }
445            }
446        }
447        _ => {
448            return Err(GpuError::BackendNotAvailable("OpenCL".to_string()));
449        }
450    }
451
452    if devices.is_empty() {
453        Err(GpuError::BackendNotAvailable("OpenCL".to_string()))
454    } else {
455        Ok(devices)
456    }
457}
458
459/// Check if a specific backend is properly installed and functional
460#[allow(dead_code)]
461pub fn check_backend_installation(backend: GpuBackend) -> Result<bool, GpuError> {
462    match backend {
463        GpuBackend::Cuda => {
464            // Check for CUDA installation
465            match Command::new("nvcc").arg("--version").output() {
466                Ok(output) if output.status.success() => Ok(true),
467                _ => Ok(false),
468            }
469        }
470        GpuBackend::Rocm => {
471            // Check for ROCm installation
472            match Command::new("hipcc").arg("--version").output() {
473                Ok(output) if output.status.success() => Ok(true),
474                _ => {
475                    // Also try rocm-smi as an alternative check
476                    match Command::new("rocm-smi").arg("--version").output() {
477                        Ok(output) if output.status.success() => Ok(true),
478                        _ => Ok(false),
479                    }
480                }
481            }
482        }
483        GpuBackend::Metal => {
484            #[cfg(target_os = "macos")]
485            {
486                // Metal is always available on macOS
487                Ok(true)
488            }
489            #[cfg(not(target_os = "macos"))]
490            {
491                Ok(false)
492            }
493        }
494        GpuBackend::OpenCL => {
495            // Check for OpenCL installation
496            match Command::new("clinfo").output() {
497                Ok(output) if output.status.success() => Ok(true),
498                _ => Ok(false),
499            }
500        }
501        GpuBackend::Wgpu => {
502            // WebGPU is always available through wgpu crate
503            Ok(true)
504        }
505        GpuBackend::Cpu => Ok(true),
506    }
507}
508
509/// Get detailed information about a specific GPU device
510#[allow(dead_code)]
511pub fn get_device_info(backend: GpuBackend, device_id: usize) -> Result<GpuInfo, GpuError> {
512    let detection_result = detect_gpu_backends();
513
514    detection_result
515        .devices
516        .into_iter()
517        .filter(|d| d.backend == backend)
518        .nth(device_id)
519        .ok_or_else(|| {
520            GpuError::InvalidParameter(format!(
521                "Device {device_id} not found for backend {:?}",
522                backend
523            ))
524        })
525}
526
527/// Initialize the optimal GPU backend for the current system
528#[allow(dead_code)]
529pub fn initialize_optimal_backend() -> Result<GpuBackend, GpuError> {
530    let detection_result = detect_gpu_backends();
531
532    // Try backends in order of preference for scientific computing
533    let preference_order = [
534        GpuBackend::Cuda,   // Best for scientific computing
535        GpuBackend::Rocm,   // Second best for scientific computing (AMD)
536        GpuBackend::Metal,  // Good on Apple hardware
537        GpuBackend::OpenCL, // Widely compatible
538        GpuBackend::Wgpu,   // Modern cross-platform
539        GpuBackend::Cpu,    // Always available fallback
540    ];
541
542    for backend in preference_order.iter() {
543        if detection_result
544            .devices
545            .iter()
546            .any(|d: &GpuInfo| d.backend == *backend)
547        {
548            return Ok(*backend);
549        }
550    }
551
552    // Should never reach here since CPU is always available
553    Ok(GpuBackend::Cpu)
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    const GIB: u64 = 1024 * 1024 * 1024;

    #[test]
    fn test_gpu_info_creation() {
        let rtx = GpuInfo {
            backend: GpuBackend::Cuda,
            device_name: "NVIDIA GeForce RTX 3080".to_string(),
            memory_bytes: Some(10 * GIB),
            compute_capability: Some("8.6".to_string()),
            supports_tensors: true,
        };

        assert_eq!(rtx.backend, GpuBackend::Cuda);
        assert_eq!(rtx.device_name, "NVIDIA GeForce RTX 3080");
        assert_eq!(rtx.memory_bytes, Some(10 * GIB));
        assert_eq!(rtx.compute_capability.as_deref(), Some("8.6"));
        assert!(rtx.supports_tensors);
    }

    #[test]
    fn test_gpu_detection_result_with_cpu_fallback() {
        let detected = detect_gpu_backends();

        // The CPU fallback guarantees at least one device.
        assert!(!detected.devices.is_empty());
        assert!(detected
            .devices
            .iter()
            .any(|info| info.backend == GpuBackend::Cpu));

        assert!(
            matches!(
                detected.recommended_backend,
                GpuBackend::Cuda
                    | GpuBackend::Rocm
                    | GpuBackend::Metal
                    | GpuBackend::OpenCL
                    | GpuBackend::Cpu
            ),
            "Unexpected recommended backend"
        );
    }

    #[test]
    fn test_check_backend_installation_cpu() {
        let installed = check_backend_installation(GpuBackend::Cpu).expect("Operation failed");
        assert!(installed);
    }

    #[test]
    fn test_check_backend_installation_wgpu() {
        let installed = check_backend_installation(GpuBackend::Wgpu).expect("Operation failed");
        assert!(installed);
    }

    #[test]
    fn test_check_backend_installation_metal() {
        let installed = check_backend_installation(GpuBackend::Metal).expect("Operation failed");
        assert_eq!(installed, cfg!(target_os = "macos"));
    }

    #[test]
    fn test_initialize_optimal_backend() {
        let backend = initialize_optimal_backend().expect("Operation failed");

        // No wildcard arm: this stops compiling if a GpuBackend variant is added.
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_get_device_info_invalid_device() {
        let lookup = get_device_info(GpuBackend::Cpu, 100);

        assert!(lookup.is_err());
        assert!(
            matches!(lookup, Err(GpuError::InvalidParameter(_))),
            "Expected InvalidParameter error"
        );
    }

    #[test]
    fn test_get_device_info_cpu() {
        let lookup = get_device_info(GpuBackend::Cpu, 0);
        assert!(lookup.is_ok());

        let cpu = lookup.expect("Operation failed");
        assert_eq!(cpu.backend, GpuBackend::Cpu);
        assert_eq!(cpu.device_name, "CPU");
        assert!(!cpu.supports_tensors);
    }

    #[test]
    fn test_detect_metal_devices_non_macos() {
        #[cfg(not(target_os = "macos"))]
        {
            assert!(
                matches!(
                    detect_metal_devices(),
                    Err(GpuError::BackendNotAvailable(_))
                ),
                "Expected BackendNotAvailable error"
            );
        }
    }

    #[test]
    fn test_gpu_info_clone() {
        let radeon = GpuInfo {
            backend: GpuBackend::Rocm,
            device_name: "AMD Radeon RX 6900 XT".to_string(),
            memory_bytes: Some(16 * GIB),
            compute_capability: Some("RDNA2".to_string()),
            supports_tensors: true,
        };

        let copy = radeon.clone();
        assert_eq!(radeon.backend, copy.backend);
        assert_eq!(radeon.device_name, copy.device_name);
        assert_eq!(radeon.memory_bytes, copy.memory_bytes);
        assert_eq!(radeon.compute_capability, copy.compute_capability);
        assert_eq!(radeon.supports_tensors, copy.supports_tensors);
    }

    #[test]
    fn test_gpu_detection_result_clone() {
        let a100 = GpuInfo {
            backend: GpuBackend::Cuda,
            device_name: "NVIDIA A100".to_string(),
            memory_bytes: Some(40 * GIB),
            compute_capability: Some("8.0".to_string()),
            supports_tensors: true,
        };
        let cpu = GpuInfo {
            backend: GpuBackend::Cpu,
            device_name: "CPU".to_string(),
            memory_bytes: None,
            compute_capability: None,
            supports_tensors: false,
        };

        let scan = GpuDetectionResult {
            devices: vec![a100, cpu],
            recommended_backend: GpuBackend::Cuda,
        };

        let copy = scan.clone();
        assert_eq!(scan.devices.len(), copy.devices.len());
        assert_eq!(scan.recommended_backend, copy.recommended_backend);
    }

    // The detection functions shell out to vendor tools. Without mocking
    // `Command` these tests can only check that calling them does not panic.
    #[test]
    fn test_detect_cuda_deviceserror_handling() {
        let _ = detect_cuda_devices();
    }

    #[test]
    fn test_detect_rocm_deviceserror_handling() {
        let _ = detect_rocm_devices();
    }

    #[test]
    fn test_detect_opencl_deviceserror_handling() {
        let _ = detect_opencl_devices();
    }

    #[test]
    fn test_backend_preference_order() {
        let detected = detect_gpu_backends();
        let cuda_present = detected
            .devices
            .iter()
            .any(|info| info.backend == GpuBackend::Cuda);

        // When CUDA is present it must win over every other backend.
        if cuda_present {
            let optimal = initialize_optimal_backend().expect("Operation failed");
            assert_eq!(optimal, GpuBackend::Cuda);
        }
    }
}