// scirs2_core/gpu/backends/mod.rs
1//! GPU backend implementations and detection utilities
2//!
3//! This module contains backend-specific implementations for various GPU platforms
4//! and utilities for detecting available GPU backends.
5
6use crate::gpu::{GpuBackend, GpuError};
7use std::process::Command;
8
9#[cfg(all(target_os = "macos", feature = "serialization"))]
10use serde_json;
11
12#[cfg(feature = "validation")]
13use regex::Regex;
14
// Backend implementation modules.
// Each backend is gated on its Cargo feature; Metal-based modules are
// additionally restricted to macOS, where the Metal API exists.
#[cfg(feature = "cuda")]
pub mod cuda;

#[cfg(feature = "opencl")]
pub mod opencl;

#[cfg(feature = "wgpu_backend")]
pub mod wgpu;

#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod metal;

#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod metal_mps;

/// MSL compute kernel source strings for the Metal backend.
///
/// Each constant is a complete Metal Shading Language kernel that can be
/// compiled at runtime by `MetalContext::create_compute_pipeline`.
#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod msl_kernels;

#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
pub mod metal_mpsgraph;

// Re-export backend implementations so callers can use
// `crate::gpu::backends::CudaContext` etc. without naming the submodule.
#[cfg(feature = "cuda")]
pub use cuda::{get_optimizer_kernels, CudaContext, CudaStream};

#[cfg(feature = "opencl")]
pub use opencl::OpenCLContext;

#[cfg(feature = "wgpu_backend")]
pub use wgpu::WebGPUContext;

#[cfg(all(feature = "metal", target_os = "macos"))]
pub use metal::{MetalBufferOptions, MetalContext, MetalStorageMode};

#[cfg(all(feature = "metal", target_os = "macos"))]
pub use metal_mps::{MPSContext, MPSDataType, MPSOperations};

#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
pub use metal_mpsgraph::MPSGraphContext;
59
/// Information about available GPU hardware.
///
/// Produced by the `detect_*` functions in this module — one entry per
/// detected device, plus the CPU fallback entry that
/// `detect_gpu_backends` always appends.
#[derive(Debug, Clone)]
pub struct GpuInfo {
    /// The GPU backend type
    pub backend: GpuBackend,
    /// Device name as reported by the vendor tool (e.g. `nvidia-smi`, `rocm-smi`)
    pub device_name: String,
    /// Available memory in bytes (`None` when detection could not determine it)
    pub memory_bytes: Option<u64>,
    /// Compute capability or equivalent (e.g. CUDA "8.6", ROCm "RDNA/CDNA",
    /// a Metal family string on macOS)
    pub compute_capability: Option<String>,
    /// Whether the device supports tensor operations
    /// (set for compute capability >= 7.0 on CUDA, always for ROCm/Metal here)
    pub supports_tensors: bool,
}
74
/// Detection results for all available GPU backends.
///
/// Returned by [`detect_gpu_backends`]; `devices` is never empty because a
/// CPU fallback entry is always included.
#[derive(Debug, Clone)]
pub struct GpuDetectionResult {
    /// Available GPU devices
    pub devices: Vec<GpuInfo>,
    /// Recommended backend for scientific computing
    /// (best detected backend, `Cpu` when no GPU was found)
    pub recommended_backend: GpuBackend,
}
83
84/// Detect available GPU backends and devices
85#[allow(dead_code)]
86pub fn detect_gpu_backends() -> GpuDetectionResult {
87    let mut devices = Vec::new();
88
89    // Skip GPU detection in test environment to avoid segfaults from external commands
90    #[cfg(not(test))]
91    {
92        // Detect CUDA devices
93        if let Ok(cuda_devices) = detect_cuda_devices() {
94            devices.extend(cuda_devices);
95        }
96
97        // Detect ROCm devices
98        if let Ok(rocm_devices) = detect_rocm_devices() {
99            devices.extend(rocm_devices);
100        }
101
102        // Detect Metal devices (macOS)
103        #[cfg(target_os = "macos")]
104        if let Ok(metal_devices) = detect_metal_devices() {
105            devices.extend(metal_devices);
106        }
107
108        // Detect OpenCL devices
109        if let Ok(opencl_devices) = detect_opencl_devices() {
110            devices.extend(opencl_devices);
111        }
112    }
113
114    // Determine recommended backend
115    let recommended_backend = if devices
116        .iter()
117        .any(|d: &GpuInfo| d.backend == GpuBackend::Cuda)
118    {
119        GpuBackend::Cuda
120    } else if devices
121        .iter()
122        .any(|d: &GpuInfo| d.backend == GpuBackend::Rocm)
123    {
124        GpuBackend::Rocm
125    } else if devices
126        .iter()
127        .any(|d: &GpuInfo| d.backend == GpuBackend::Metal)
128    {
129        GpuBackend::Metal
130    } else if devices
131        .iter()
132        .any(|d: &GpuInfo| d.backend == GpuBackend::OpenCL)
133    {
134        GpuBackend::OpenCL
135    } else {
136        GpuBackend::Cpu
137    };
138
139    // Always add CPU fallback
140    devices.push(GpuInfo {
141        backend: GpuBackend::Cpu,
142        device_name: "CPU".to_string(),
143        memory_bytes: None,
144        compute_capability: None,
145        supports_tensors: false,
146    });
147
148    GpuDetectionResult {
149        devices,
150        recommended_backend,
151    }
152}
153
154/// Detect ROCm devices using rocm-smi
155#[allow(dead_code)]
156fn detect_rocm_devices() -> Result<Vec<GpuInfo>, GpuError> {
157    let mut devices = Vec::new();
158
159    // Try to run rocm-smi to detect ROCm devices
160    match Command::new("rocm-smi")
161        .arg("--showproductname")
162        .arg("--showmeminfo")
163        .arg("vram")
164        .arg("--csv")
165        .output()
166    {
167        Ok(output) if output.status.success() => {
168            let output_str = String::from_utf8_lossy(&output.stdout);
169
170            for line in output_str.lines().skip(1) {
171                // Skip header line
172                if line.trim().is_empty() {
173                    continue;
174                }
175
176                let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
177                if parts.len() >= 3 {
178                    let device_name = parts[1].trim_matches('"').to_string();
179                    let memory_str = parts[2].trim_matches('"');
180
181                    // Parse memory (format might be like "16368 MB")
182                    let memory_mb = memory_str
183                        .split_whitespace()
184                        .next()
185                        .and_then(|s| s.parse::<u64>().ok())
186                        .unwrap_or(0)
187                        * 1024
188                        * 1024; // Convert MB to bytes
189
190                    devices.push(GpuInfo {
191                        backend: GpuBackend::Rocm,
192                        device_name,
193                        memory_bytes: Some(memory_mb),
194                        compute_capability: Some("RDNA/CDNA".to_string()),
195                        supports_tensors: true, // Modern AMD GPUs support matrix operations
196                    });
197                }
198            }
199        }
200        _ => {
201            // rocm-smi not available or failed
202            // In a real implementation, we could try other methods like:
203            // - Direct HIP runtime API calls
204            // - /sys/class/drm/cardX/ on Linux
205            // - rocminfo command
206        }
207    }
208
209    if devices.is_empty() {
210        Err(GpuError::BackendNotAvailable("ROCm".to_string()))
211    } else {
212        Ok(devices)
213    }
214}
215
216/// Detect CUDA devices using nvidia-ml-py or nvidia-smi
217#[allow(dead_code)]
218fn detect_cuda_devices() -> Result<Vec<GpuInfo>, GpuError> {
219    let mut devices = Vec::new();
220
221    // Try to run nvidia-smi to detect CUDA devices
222    match Command::new("nvidia-smi")
223        .arg("--query-gpu=name,memory.total,compute_cap")
224        .arg("--format=csv,noheader,nounits")
225        .output()
226    {
227        Ok(output) if output.status.success() => {
228            let output_str = String::from_utf8_lossy(&output.stdout);
229
230            for line in output_str.lines() {
231                if line.trim().is_empty() {
232                    continue;
233                }
234
235                let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
236                if parts.len() >= 3 {
237                    let device_name = parts[0].to_string();
238                    let memory_mb = parts[1].parse::<u64>().unwrap_or(0) * 1024 * 1024; // Convert MB to bytes
239                    let compute_capability = parts[2].to_string();
240
241                    // Parse compute capability to determine tensor core support
242                    let supports_tensors =
243                        if let Some(major_str) = compute_capability.split('.').next() {
244                            major_str.parse::<u32>().unwrap_or(0) >= 7 // Tensor cores available on Volta+ (7.0+)
245                        } else {
246                            false
247                        };
248
249                    devices.push(GpuInfo {
250                        backend: GpuBackend::Cuda,
251                        device_name,
252                        memory_bytes: Some(memory_mb),
253                        compute_capability: Some(compute_capability),
254                        supports_tensors,
255                    });
256                }
257            }
258        }
259        _ => {
260            // nvidia-smi not available or failed
261            // In a real implementation, we could try other methods like:
262            // - Direct CUDA runtime API calls
263            // - nvidia-ml-py if available
264            // - /proc/driver/nvidia/gpus/ on Linux
265        }
266    }
267
268    if devices.is_empty() {
269        Err(GpuError::BackendNotAvailable("CUDA".to_string()))
270    } else {
271        Ok(devices)
272    }
273}
274
/// Detect Metal devices (macOS only).
///
/// Strategy, in order:
/// 1. Run `system_profiler SPDisplaysDataType -json` and, when the
///    `serialization` feature is on, parse the JSON for each display's
///    model, VRAM (regex parse requires the `validation` feature) and
///    Metal family string.
/// 2. If that yields nothing (or JSON parsing was compiled out), fall back
///    to the Metal API (`Device::system_default`) when the `metal` feature
///    is on, or to a generic placeholder entry otherwise.
///
/// Returns `GpuError::BackendNotAvailable` when no device can be found.
#[cfg(target_os = "macos")]
#[allow(dead_code)]
fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
    let mut devices = Vec::new();

    // Try to detect Metal devices using system_profiler
    match Command::new("system_profiler")
        .arg("SPDisplaysDataType")
        .arg("-json")
        .output()
    {
        Ok(output) if output.status.success() => {
            // Try to parse JSON output (requires serialization feature for serde_json)
            #[cfg(feature = "serialization")]
            {
                use std::str::FromStr;
                let output_str = String::from_utf8_lossy(&output.stdout);

                if let Ok(json_value) = serde_json::Value::from_str(&output_str) {
                    if let Some(displays) = json_value
                        .get("SPDisplaysDataType")
                        .and_then(|v: &serde_json::Value| v.as_array())
                    {
                        // Pre-compile regex outside loop for performance
                        #[cfg(feature = "validation")]
                        let vram_regex = Regex::new(r"(\d+)\s*(GB|MB)").ok();

                        for display in displays {
                            // Extract GPU information from each display
                            if let Some(model) = display
                                .get("sppci_model")
                                .and_then(|v: &serde_json::Value| v.as_str())
                            {
                                let mut gpu_info = GpuInfo {
                                    backend: GpuBackend::Metal,
                                    device_name: model.to_string(),
                                    memory_bytes: None,
                                    compute_capability: None,
                                    supports_tensors: true,
                                };

                                // Try to extract VRAM if available.
                                // Discrete GPUs report "vram_pcie"; others "vram".
                                if let Some(vram_str) = display
                                    .get("vram_pcie")
                                    .and_then(|v: &serde_json::Value| v.as_str())
                                    .or_else(|| {
                                        display
                                            .get("vram")
                                            .and_then(|v: &serde_json::Value| v.as_str())
                                    })
                                {
                                    // Parse VRAM string like "8 GB" or "8192 MB"
                                    #[cfg(feature = "validation")]
                                    if let Some(captures) =
                                        vram_regex.as_ref().and_then(|re| re.captures(vram_str))
                                    {
                                        if let (Some(value), Some(unit)) =
                                            (captures.get(1), captures.get(2))
                                        {
                                            if let Ok(num) = u64::from_str(value.as_str()) {
                                                gpu_info.memory_bytes = Some(match unit.as_str() {
                                                    "GB" => num * 1024 * 1024 * 1024,
                                                    "MB" => num * 1024 * 1024,
                                                    _ => 0,
                                                });
                                            }
                                        }
                                    }
                                }

                                // Extract Metal family support
                                if let Some(metal_family) = display
                                    .get("sppci_metal_family")
                                    .and_then(|v: &serde_json::Value| v.as_str())
                                {
                                    gpu_info.compute_capability = Some(metal_family.to_string());
                                }

                                devices.push(gpu_info);
                            }
                        }
                    }
                }
            }

            // If JSON parsing failed, was skipped, or no devices found, try to detect via Metal API
            if devices.is_empty() {
                // Check if Metal is available
                #[cfg(feature = "metal")]
                {
                    use metal::Device;
                    if let Some(device) = Device::system_default() {
                        let name = device.name().to_string();
                        let mut gpu_info = GpuInfo {
                            backend: GpuBackend::Metal,
                            device_name: name.clone(),
                            memory_bytes: None,
                            compute_capability: None,
                            supports_tensors: true,
                        };

                        // GPU family detection would go here
                        // Note: MTLGPUFamily is not exposed in the current metal crate
                        gpu_info.compute_capability = Some("Metal GPU".to_string());

                        devices.push(gpu_info);
                    }
                }

                // Fallback if Metal crate not available but we're on macOS
                #[cfg(not(feature = "metal"))]
                {
                    devices.push(GpuInfo {
                        backend: GpuBackend::Metal,
                        device_name: "Metal GPU".to_string(),
                        memory_bytes: None,
                        compute_capability: None,
                        supports_tensors: true,
                    });
                }
            }
        }
        _ => {
            // system_profiler failed, try Metal API directly
            #[cfg(feature = "metal")]
            {
                use metal::Device;
                if let Some(device) = Device::system_default() {
                    devices.push(GpuInfo {
                        backend: GpuBackend::Metal,
                        device_name: device.name().to_string(),
                        memory_bytes: None,
                        compute_capability: None,
                        supports_tensors: true,
                    });
                } else {
                    return Err(GpuError::BackendNotAvailable("Metal".to_string()));
                }
            }

            #[cfg(not(feature = "metal"))]
            {
                return Err(GpuError::BackendNotAvailable("Metal".to_string()));
            }
        }
    }

    if devices.is_empty() {
        Err(GpuError::BackendNotAvailable("Metal".to_string()))
    } else {
        Ok(devices)
    }
}
429
430/// Detect Metal devices (non-macOS - not available)
431#[cfg(not(target_os = "macos"))]
432#[allow(dead_code)]
433fn detect_metal_devices() -> Result<Vec<GpuInfo>, GpuError> {
434    Err(GpuError::BackendNotAvailable(
435        "Metal (not macOS)".to_string(),
436    ))
437}
438
439/// Detect OpenCL devices
440#[allow(dead_code)]
441fn detect_opencl_devices() -> Result<Vec<GpuInfo>, GpuError> {
442    let mut devices = Vec::new();
443
444    // Try to detect OpenCL devices using clinfo
445    match Command::new("clinfo").arg("--list").output() {
446        Ok(output) if output.status.success() => {
447            let output_str = String::from_utf8_lossy(&output.stdout);
448
449            for line in output_str.lines() {
450                if line.trim().starts_with("Platform") || line.trim().starts_with("Device") {
451                    // In a real implementation, we would parse clinfo output properly
452                    // For now, just add a generic OpenCL device
453                    devices.push(GpuInfo {
454                        backend: GpuBackend::OpenCL,
455                        device_name: "OpenCL Device".to_string(),
456                        memory_bytes: None,
457                        compute_capability: None,
458                        supports_tensors: false,
459                    });
460                    break; // Just add one for demo
461                }
462            }
463        }
464        _ => {
465            return Err(GpuError::BackendNotAvailable("OpenCL".to_string()));
466        }
467    }
468
469    if devices.is_empty() {
470        Err(GpuError::BackendNotAvailable("OpenCL".to_string()))
471    } else {
472        Ok(devices)
473    }
474}
475
476/// Check if a specific backend is properly installed and functional
477#[allow(dead_code)]
478pub fn check_backend_installation(backend: GpuBackend) -> Result<bool, GpuError> {
479    match backend {
480        GpuBackend::Cuda => {
481            // Check for CUDA installation
482            match Command::new("nvcc").arg("--version").output() {
483                Ok(output) if output.status.success() => Ok(true),
484                _ => Ok(false),
485            }
486        }
487        GpuBackend::Rocm => {
488            // Check for ROCm installation
489            match Command::new("hipcc").arg("--version").output() {
490                Ok(output) if output.status.success() => Ok(true),
491                _ => {
492                    // Also try rocm-smi as an alternative check
493                    match Command::new("rocm-smi").arg("--version").output() {
494                        Ok(output) if output.status.success() => Ok(true),
495                        _ => Ok(false),
496                    }
497                }
498            }
499        }
500        GpuBackend::Metal => {
501            #[cfg(target_os = "macos")]
502            {
503                // Metal is always available on macOS
504                Ok(true)
505            }
506            #[cfg(not(target_os = "macos"))]
507            {
508                Ok(false)
509            }
510        }
511        GpuBackend::OpenCL => {
512            // Check for OpenCL installation
513            match Command::new("clinfo").output() {
514                Ok(output) if output.status.success() => Ok(true),
515                _ => Ok(false),
516            }
517        }
518        GpuBackend::Wgpu => {
519            // WebGPU is always available through wgpu crate
520            Ok(true)
521        }
522        GpuBackend::Cpu => Ok(true),
523    }
524}
525
526/// Get detailed information about a specific GPU device
527#[allow(dead_code)]
528pub fn get_device_info(backend: GpuBackend, device_id: usize) -> Result<GpuInfo, GpuError> {
529    let detection_result = detect_gpu_backends();
530
531    detection_result
532        .devices
533        .into_iter()
534        .filter(|d| d.backend == backend)
535        .nth(device_id)
536        .ok_or_else(|| {
537            GpuError::InvalidParameter(format!(
538                "Device {device_id} not found for backend {:?}",
539                backend
540            ))
541        })
542}
543
544/// Initialize the optimal GPU backend for the current system
545#[allow(dead_code)]
546pub fn initialize_optimal_backend() -> Result<GpuBackend, GpuError> {
547    let detection_result = detect_gpu_backends();
548
549    // Try backends in order of preference for scientific computing
550    let preference_order = [
551        GpuBackend::Cuda,   // Best for scientific computing
552        GpuBackend::Rocm,   // Second best for scientific computing (AMD)
553        GpuBackend::Metal,  // Good on Apple hardware
554        GpuBackend::OpenCL, // Widely compatible
555        GpuBackend::Wgpu,   // Modern cross-platform
556        GpuBackend::Cpu,    // Always available fallback
557    ];
558
559    for backend in preference_order.iter() {
560        if detection_result
561            .devices
562            .iter()
563            .any(|d: &GpuInfo| d.backend == *backend)
564        {
565            return Ok(*backend);
566        }
567    }
568
569    // Should never reach here since CPU is always available
570    Ok(GpuBackend::Cpu)
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: detect_gpu_backends() skips the external-command probes under
    // cfg(test), so these tests exercise only the CPU-fallback path and
    // the pure selection logic — they never spawn vendor tools.

    // Construct a GpuInfo by hand and read every field back.
    #[test]
    fn test_gpu_info_creation() {
        let info = GpuInfo {
            backend: GpuBackend::Cuda,
            device_name: "NVIDIA GeForce RTX 3080".to_string(),
            memory_bytes: Some(10 * 1024 * 1024 * 1024), // 10GB
            compute_capability: Some("8.6".to_string()),
            supports_tensors: true,
        };

        assert_eq!(info.backend, GpuBackend::Cuda);
        assert_eq!(info.device_name, "NVIDIA GeForce RTX 3080");
        assert_eq!(info.memory_bytes, Some(10 * 1024 * 1024 * 1024));
        assert_eq!(info.compute_capability, Some("8.6".to_string()));
        assert!(info.supports_tensors);
    }

    // The detection result must contain the CPU fallback and a valid
    // recommended backend even when no GPU is probed.
    #[test]
    fn test_gpu_detection_result_with_cpu_fallback() {
        let result = detect_gpu_backends();

        // Should always have at least CPU fallback
        assert!(!result.devices.is_empty());
        assert!(result
            .devices
            .iter()
            .any(|d: &GpuInfo| d.backend == GpuBackend::Cpu));

        // Should have a recommended backend
        match result.recommended_backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
            _ => panic!("Unexpected recommended backend"),
        }
    }

    #[test]
    fn test_check_backend_installation_cpu() {
        // CPU should always be available
        let result = check_backend_installation(GpuBackend::Cpu).expect("Operation failed");
        assert!(result);
    }

    #[test]
    fn test_check_backend_installation_wgpu() {
        // WebGPU should always be available through wgpu crate
        let result = check_backend_installation(GpuBackend::Wgpu).expect("Operation failed");
        assert!(result);
    }

    // Metal availability is purely a function of the target OS.
    #[test]
    fn test_check_backend_installation_metal() {
        let result = check_backend_installation(GpuBackend::Metal).expect("Operation failed");
        #[cfg(target_os = "macos")]
        assert!(result);
        #[cfg(not(target_os = "macos"))]
        assert!(!result);
    }

    #[test]
    fn test_initialize_optimal_backend() {
        let backend = initialize_optimal_backend().expect("Operation failed");

        // Should return a valid backend
        match backend {
            GpuBackend::Cuda
            | GpuBackend::Rocm
            | GpuBackend::Wgpu
            | GpuBackend::Metal
            | GpuBackend::OpenCL
            | GpuBackend::Cpu => {}
        }
    }

    #[test]
    fn test_get_device_info_invalid_device() {
        // Try to get info for a non-existent device
        let result = get_device_info(GpuBackend::Cpu, 100);

        assert!(result.is_err());
        match result {
            Err(GpuError::InvalidParameter(_)) => {}
            _ => panic!("Expected InvalidParameter error"),
        }
    }

    #[test]
    fn test_get_device_info_cpu() {
        // CPU device should always be available
        let result = get_device_info(GpuBackend::Cpu, 0);

        assert!(result.is_ok());
        let info = result.expect("Operation failed");
        assert_eq!(info.backend, GpuBackend::Cpu);
        assert_eq!(info.device_name, "CPU");
        assert!(!info.supports_tensors);
    }

    // On non-macOS targets the Metal stub must report BackendNotAvailable.
    #[test]
    fn test_detect_metal_devices_non_macos() {
        #[cfg(not(target_os = "macos"))]
        {
            let result = detect_metal_devices();
            assert!(result.is_err());
            match result {
                Err(GpuError::BackendNotAvailable(_)) => {}
                _ => panic!("Expected BackendNotAvailable error"),
            }
        }
    }

    // GpuInfo derives Clone; a clone must compare field-equal.
    #[test]
    fn test_gpu_info_clone() {
        let info = GpuInfo {
            backend: GpuBackend::Rocm,
            device_name: "AMD Radeon RX 6900 XT".to_string(),
            memory_bytes: Some(16 * 1024 * 1024 * 1024), // 16GB
            compute_capability: Some("RDNA2".to_string()),
            supports_tensors: true,
        };

        let cloned = info.clone();
        assert_eq!(info.backend, cloned.backend);
        assert_eq!(info.device_name, cloned.device_name);
        assert_eq!(info.memory_bytes, cloned.memory_bytes);
        assert_eq!(info.compute_capability, cloned.compute_capability);
        assert_eq!(info.supports_tensors, cloned.supports_tensors);
    }

    // GpuDetectionResult derives Clone as well.
    #[test]
    fn test_gpu_detection_result_clone() {
        let devices = vec![
            GpuInfo {
                backend: GpuBackend::Cuda,
                device_name: "NVIDIA A100".to_string(),
                memory_bytes: Some(40 * 1024 * 1024 * 1024),
                compute_capability: Some("8.0".to_string()),
                supports_tensors: true,
            },
            GpuInfo {
                backend: GpuBackend::Cpu,
                device_name: "CPU".to_string(),
                memory_bytes: None,
                compute_capability: None,
                supports_tensors: false,
            },
        ];

        let result = GpuDetectionResult {
            devices: devices.clone(),
            recommended_backend: GpuBackend::Cuda,
        };

        let cloned = result.clone();
        assert_eq!(result.devices.len(), cloned.devices.len());
        assert_eq!(result.recommended_backend, cloned.recommended_backend);
    }

    // Mock tests to verify error handling in detection functions
    #[test]
    fn test_detect_cuda_deviceserror_handling() {
        // In the real implementation, detect_cuda_devices returns an error
        // when nvidia-smi is not available. We can't easily test this without
        // mocking the Command execution, but we can at least call the function
        let _ = detect_cuda_devices();
    }

    #[test]
    fn test_detect_rocm_deviceserror_handling() {
        // Similar to CUDA test
        let _ = detect_rocm_devices();
    }

    #[test]
    fn test_detect_opencl_deviceserror_handling() {
        // Similar to CUDA test
        let _ = detect_opencl_devices();
    }

    #[test]
    fn test_backend_preference_order() {
        // Test that initialize_optimal_backend respects the preference order
        let result = detect_gpu_backends();

        // If we have multiple backends, the recommended should follow preference
        if result
            .devices
            .iter()
            .any(|d: &GpuInfo| d.backend == GpuBackend::Cuda)
        {
            // If CUDA is available, it should be preferred
            let optimal = initialize_optimal_backend().expect("Operation failed");
            if result
                .devices
                .iter()
                .filter(|d| d.backend == GpuBackend::Cuda)
                .count()
                > 0
            {
                assert_eq!(optimal, GpuBackend::Cuda);
            }
        }
    }
}