scirs2_sparse/gpu/metal.rs

//! Metal backend for sparse matrix GPU operations on Apple platforms
//!
//! This module provides Metal-specific implementations for sparse matrix operations
//! optimized for Apple Silicon and Intel Macs with discrete GPUs.
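//!
//! # Example
//!
//! A minimal usage sketch (illustrative only): it assumes the `gpu` feature is
//! enabled and that a `CsrArray<f32>` named `matrix`, a dense `Array1<f32>`
//! named `x`, and a `GpuDevice` handle named `device` already exist; their
//! construction is omitted here. When Metal is unavailable the call falls back
//! to the CPU path.
//!
//! ```ignore
//! let spmv = MetalSpMatVec::new()?;
//! let y = spmv.execute_spmv(&matrix, &x.view(), &device)?;
//! assert_eq!(y.len(), matrix.shape().0);
//! ```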

use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;

#[cfg(feature = "gpu")]
use crate::gpu_kernel_execution::{GpuKernelConfig, MemoryStrategy};

#[cfg(feature = "gpu")]
pub use scirs2_core::gpu::{GpuBackend, GpuBuffer, GpuContext, GpuDataType, GpuKernelHandle};

#[cfg(feature = "gpu")]
pub use scirs2_core::GpuError;

/// Metal shader source code for sparse matrix-vector multiplication
pub const METAL_SPMV_SHADER_SOURCE: &str = r#"
#include <metal_stdlib>
using namespace metal;

kernel void spmv_csr_kernel(
    device const int* indptr [[buffer(0)]],
    device const int* indices [[buffer(1)]],
    device const float* data [[buffer(2)]],
    device const float* x [[buffer(3)]],
    device float* y [[buffer(4)]],
    constant int& rows [[buffer(5)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= uint(rows)) return;

    float sum = 0.0f;
    int start = indptr[gid];
    int end = indptr[gid + 1];

    for (int j = start; j < end; j++) {
        sum += data[j] * x[indices[j]];
    }

    y[gid] = sum;
}

kernel void spmv_csr_simdgroup_kernel(
    device const int* indptr [[buffer(0)]],
    device const int* indices [[buffer(1)]],
    device const float* data [[buffer(2)]],
    device const float* x [[buffer(3)]],
    device float* y [[buffer(4)]],
    constant int& rows [[buffer(5)]],
    uint gid [[thread_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]
) {
    // One 32-lane SIMD group cooperates on a single row, so the host must
    // launch 32 threads per row (see the host-side grid sizing).
    uint row = gid / 32;
    if (row >= uint(rows)) return;

    int start = indptr[row];
    int end = indptr[row + 1];
    float sum = 0.0f;

    // Each lane handles a strided slice of the row for better performance
    // on Apple Silicon.
    for (int j = start + int(simd_lane_id); j < end; j += 32) {
        sum += data[j] * x[indices[j]];
    }

    // SIMD group reduction combines the per-lane partial sums.
    sum = simd_sum(sum);

    if (simd_lane_id == 0) {
        y[row] = sum;
    }
}
"#;

/// Metal shader for Apple Silicon optimized operations
pub const METAL_APPLE_SILICON_SHADER_SOURCE: &str = r#"
#include <metal_stdlib>
using namespace metal;

kernel void spmv_csr_apple_silicon_kernel(
    device const int* indptr [[buffer(0)]],
    device const int* indices [[buffer(1)]],
    device const float* data [[buffer(2)]],
    device const float* x [[buffer(3)]],
    device float* y [[buffer(4)]],
    constant int& rows [[buffer(5)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    threadgroup float* shared_data [[threadgroup(0)]]
) {
    // Every thread must reach the threadgroup barriers, so out-of-range
    // threads are masked rather than returning early.
    bool active = gid < uint(rows);
    int start = active ? indptr[gid] : 0;
    int end = active ? indptr[gid + 1] : 0;

    // Stage partial sums in threadgroup memory to use the unified memory
    // architecture efficiently.
    shared_data[lid] = 0.0f;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = start; j < end; j++) {
        shared_data[lid] += data[j] * x[indices[j]];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (active) {
        y[gid] = shared_data[lid];
    }
}

kernel void spmv_csr_neural_engine_prep_kernel(
    device const int* indptr [[buffer(0)]],
    device const int* indices [[buffer(1)]],
    device const float* data [[buffer(2)]],
    device const float* x [[buffer(3)]],
    device float* y [[buffer(4)]],
    constant int& rows [[buffer(5)]],
    uint gid [[thread_position_in_grid]]
) {
    // Prepare data layout for potential Neural Engine acceleration
    if (gid >= uint(rows)) return;

    int start = indptr[gid];
    int end = indptr[gid + 1];
    float sum = 0.0f;

    // Use float4 for better throughput on Apple Silicon
    int j = start;
    for (; j + 3 < end; j += 4) {
        float4 data_vec = float4(data[j], data[j+1], data[j+2], data[j+3]);
        float4 x_vec = float4(
            x[indices[j]],
            x[indices[j+1]],
            x[indices[j+2]],
            x[indices[j+3]]
        );
        float4 prod = data_vec * x_vec;
        sum += prod.x + prod.y + prod.z + prod.w;
    }

    // Handle remaining elements
    for (; j < end; j++) {
        sum += data[j] * x[indices[j]];
    }

    y[gid] = sum;
}
"#;

/// Metal sparse matrix operations
pub struct MetalSpMatVec {
    context: Option<scirs2_core::gpu::GpuContext>,
    kernel_handle: Option<scirs2_core::gpu::GpuKernelHandle>,
    simdgroup_kernel: Option<scirs2_core::gpu::GpuKernelHandle>,
    apple_silicon_kernel: Option<scirs2_core::gpu::GpuKernelHandle>,
    neural_engine_kernel: Option<scirs2_core::gpu::GpuKernelHandle>,
    device_info: MetalDeviceInfo,
}

impl MetalSpMatVec {
    /// Create a new Metal sparse matrix-vector multiplication handler
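    ///
    /// # Example
    ///
    /// A minimal sketch; the handler silently falls back to CPU execution when
    /// no Metal context can be created:
    ///
    /// ```ignore
    /// let spmv = MetalSpMatVec::new()?;
    /// ```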
    pub fn new() -> SparseResult<Self> {
        // Try to create Metal context
        #[cfg(feature = "gpu")]
        let context = match scirs2_core::gpu::GpuContext::new(scirs2_core::gpu::GpuBackend::Metal) {
            Ok(ctx) => Some(ctx),
            Err(_) => None, // Metal not available, will use CPU fallback
        };
        #[cfg(not(feature = "gpu"))]
        let context = None;

        let mut handler = Self {
            context,
            kernel_handle: None,
            simdgroup_kernel: None,
            apple_silicon_kernel: None,
            neural_engine_kernel: None,
            device_info: MetalDeviceInfo::detect(),
        };

        // Compile kernels if context is available
        #[cfg(feature = "gpu")]
        if handler.context.is_some() {
            let _ = handler.compile_kernels();
        }

        Ok(handler)
    }

    /// Compile Metal shaders for sparse matrix operations
    #[cfg(feature = "gpu")]
    pub fn compile_kernels(&mut self) -> Result<(), scirs2_core::gpu::GpuError> {
        if let Some(ref context) = self.context {
            // Compile kernels using the context
            self.kernel_handle =
                context.execute(|compiler| compiler.compile(METAL_SPMV_SHADER_SOURCE).ok());

            self.simdgroup_kernel =
                context.execute(|compiler| compiler.compile(METAL_SPMV_SHADER_SOURCE).ok());

            // Apple Silicon specific optimizations
            if self.device_info.is_apple_silicon {
                self.apple_silicon_kernel = context
                    .execute(|compiler| compiler.compile(METAL_APPLE_SILICON_SHADER_SOURCE).ok());

                // Neural Engine kernel would compile the same shader separately
                if self.device_info.has_neural_engine {
                    self.neural_engine_kernel = context.execute(|compiler| {
                        compiler.compile(METAL_APPLE_SILICON_SHADER_SOURCE).ok()
                    });
                }
            }

            if self.kernel_handle.is_some() {
                Ok(())
            } else {
                Err(scirs2_core::gpu::GpuError::KernelCompilationError(
                    "Failed to compile Metal kernels".to_string(),
                ))
            }
        } else {
            Err(scirs2_core::gpu::GpuError::BackendNotAvailable(
                "Metal".to_string(),
            ))
        }
    }

    /// Execute Metal sparse matrix-vector multiplication
    #[cfg(feature = "gpu")]
    pub fn execute_spmv<T>(
        &self,
        matrix: &CsrArray<T>,
        vector: &ArrayView1<T>,
        _device: &super::GpuDevice,
    ) -> SparseResult<Array1<T>>
    where
        T: Float + SparseElement + Debug + Copy + scirs2_core::gpu::GpuDataType,
    {
        let (rows, cols) = matrix.shape();
        if cols != vector.len() {
            return Err(SparseError::DimensionMismatch {
                expected: cols,
                found: vector.len(),
            });
        }

        if let Some(ref context) = self.context {
            // Select the best kernel based on device capabilities
            let kernel = if self.device_info.is_apple_silicon {
                self.apple_silicon_kernel
                    .as_ref()
                    .or(self.simdgroup_kernel.as_ref())
                    .or(self.kernel_handle.as_ref())
            } else {
                self.simdgroup_kernel
                    .as_ref()
                    .or(self.kernel_handle.as_ref())
            };

            if let Some(kernel) = kernel {
                // Upload data to GPU
                let indptr_buffer = context.create_buffer_from_slice(
                    matrix.get_indptr().as_slice().expect("Operation failed"),
                );
                let indices_buffer = context.create_buffer_from_slice(
                    matrix.get_indices().as_slice().expect("Operation failed"),
                );
                let data_buffer = context.create_buffer_from_slice(
                    matrix.get_data().as_slice().expect("Operation failed"),
                );
                let vector_buffer =
                    context.create_buffer_from_slice(vector.as_slice().expect("Operation failed"));
                let result_buffer = context.create_buffer::<T>(rows);

                // Set kernel parameters
                kernel.set_buffer("indptr", &indptr_buffer);
                kernel.set_buffer("indices", &indices_buffer);
                kernel.set_buffer("data", &data_buffer);
                kernel.set_buffer("x", &vector_buffer);
                kernel.set_buffer("y", &result_buffer);
                // The shader declares this constant as `rows` (buffer 5).
                kernel.set_u32("rows", rows as u32);

                // Configure threadgroup size for Metal
                let threadgroup_size = self.device_info.max_threadgroup_size.min(256);
                let grid_size = ((rows + threadgroup_size - 1) / threadgroup_size, 1, 1);
                let block_size = (threadgroup_size, 1, 1);

                // Execute kernel
                let args = vec![scirs2_core::gpu::DynamicKernelArg::U32(rows as u32)];

                context
                    .launch_kernel("spmv_csr_kernel", grid_size, block_size, &args)
                    .map_err(|e| {
                        SparseError::ComputationError(format!(
                            "Metal kernel execution failed: {:?}",
                            e
                        ))
                    })?;

                // Read result back
                let mut result_vec = vec![T::sparse_zero(); rows];
                result_buffer.copy_to_host(&mut result_vec).map_err(|e| {
                    SparseError::ComputationError(format!(
                        "Failed to copy result from GPU: {:?}",
                        e
                    ))
                })?;
                Ok(Array1::from_vec(result_vec))
            } else {
                Err(SparseError::ComputationError(
                    "Metal kernel not compiled".to_string(),
                ))
            }
        } else {
            // Fallback to CPU implementation
            matrix.dot_vector(vector)
        }
    }

    /// Execute optimized Metal sparse matrix-vector multiplication
    #[cfg(feature = "gpu")]
    pub fn execute_optimized_spmv<T>(
        &self,
        matrix: &CsrArray<T>,
        vector: &ArrayView1<T>,
        device: &super::GpuDevice,
        optimization_level: MetalOptimizationLevel,
    ) -> SparseResult<Array1<T>>
    where
        T: Float + SparseElement + Debug + Copy + super::GpuDataType,
    {
        let (rows, cols) = matrix.shape();
        if cols != vector.len() {
            return Err(SparseError::DimensionMismatch {
                expected: cols,
                found: vector.len(),
            });
        }

        // Choose kernel based on optimization level and device capabilities
        let kernel = match optimization_level {
            MetalOptimizationLevel::Basic => &self.kernel_handle,
            MetalOptimizationLevel::SimdGroup => &self.simdgroup_kernel,
            MetalOptimizationLevel::AppleSilicon => &self.apple_silicon_kernel,
            MetalOptimizationLevel::NeuralEngine => &self.neural_engine_kernel,
        };

        if let Some(ref k) = kernel {
            self.execute_kernel_with_optimization(matrix, vector, device, k, optimization_level)
        } else {
            // Fallback to basic kernel if specific optimization not available
            if let Some(ref basic_kernel) = self.kernel_handle {
                self.execute_kernel_with_optimization(
                    matrix,
                    vector,
                    device,
                    basic_kernel,
                    MetalOptimizationLevel::Basic,
                )
            } else {
                Err(SparseError::ComputationError(
                    "No Metal kernels available".to_string(),
                ))
            }
        }
    }

    #[cfg(feature = "gpu")]
    fn execute_kernel_with_optimization<T>(
        &self,
        matrix: &CsrArray<T>,
        vector: &ArrayView1<T>,
        _device: &super::GpuDevice,
        _kernel: &super::GpuKernelHandle,
        optimization_level: MetalOptimizationLevel,
    ) -> SparseResult<Array1<T>>
    where
        T: Float + SparseElement + Debug + Copy + super::GpuDataType,
    {
        let (rows, _) = matrix.shape();

        if let Some(ref context) = self.context {
            // Upload data to GPU using context
            let indptr_gpu = context.create_buffer_from_slice(
                matrix.get_indptr().as_slice().expect("Operation failed"),
            );
            let indices_gpu = context.create_buffer_from_slice(
                matrix.get_indices().as_slice().expect("Operation failed"),
            );
            let data_gpu = context
                .create_buffer_from_slice(matrix.get_data().as_slice().expect("Operation failed"));
            let vector_gpu =
                context.create_buffer_from_slice(vector.as_slice().expect("Operation failed"));
            let result_gpu = context.create_buffer::<T>(rows);

            // Configure launch parameters based on optimization level
            let (threadgroup_size, _uses_shared_memory) = match optimization_level {
                MetalOptimizationLevel::Basic => {
                    (self.device_info.max_threadgroup_size.min(64), false)
                }
                MetalOptimizationLevel::SimdGroup => {
                    (self.device_info.max_threadgroup_size.min(128), false)
                }
                MetalOptimizationLevel::AppleSilicon => {
                    (self.device_info.max_threadgroup_size.min(256), true)
                }
                MetalOptimizationLevel::NeuralEngine => {
                    // Optimize for Neural Engine pipeline
                    (self.device_info.max_threadgroup_size.min(128), false)
                }
            };

            // The SIMD-group kernel assigns one 32-lane SIMD group per row and
            // therefore needs 32 threads per row; the other kernels use one
            // thread per row.
            let total_threads = match optimization_level {
                MetalOptimizationLevel::SimdGroup => rows * 32,
                _ => rows,
            };
            let grid_size = (total_threads + threadgroup_size - 1) / threadgroup_size;

            // Launch kernel using context
            let args = vec![scirs2_core::gpu::DynamicKernelArg::U32(rows as u32)];

            // Use appropriate kernel based on optimization level
            let kernel_name = match optimization_level {
                MetalOptimizationLevel::Basic => "spmv_csr_kernel",
                MetalOptimizationLevel::SimdGroup => "spmv_csr_simdgroup_kernel",
                MetalOptimizationLevel::AppleSilicon => "spmv_csr_apple_silicon_kernel",
                MetalOptimizationLevel::NeuralEngine => "spmv_csr_neural_engine_prep_kernel",
            };

            context
                .launch_kernel(
                    kernel_name,
                    (grid_size, 1, 1),
                    (threadgroup_size, 1, 1),
                    &args,
                )
                .map_err(|e| {
                    SparseError::ComputationError(format!("Metal kernel execution failed: {:?}", e))
                })?;

            // Download result
            let mut result_vec = vec![T::sparse_zero(); rows];
            result_gpu.copy_to_host(&mut result_vec).map_err(|e| {
                SparseError::ComputationError(format!("Failed to copy result from GPU: {:?}", e))
            })?;
            Ok(Array1::from_vec(result_vec))
        } else {
            // Fallback to CPU implementation
            matrix.dot_vector(vector)
        }
    }

    /// Select optimal kernel based on device and matrix characteristics
    #[cfg(feature = "gpu")]
    fn select_optimal_kernel<T>(
        &self,
        rows: usize,
        matrix: &CsrArray<T>,
    ) -> SparseResult<super::GpuKernelHandle>
    where
        T: Float + SparseElement + Debug + Copy,
    {
        let avg_nnz_per_row = matrix.get_data().len() as f64 / rows as f64;

        // Select kernel based on device capabilities and matrix characteristics
        if self.device_info.is_apple_silicon && avg_nnz_per_row > 16.0 {
            // Use Apple Silicon optimized kernel for dense-ish matrices
            if let Some(ref kernel) = self.apple_silicon_kernel {
                Ok(kernel.clone())
            } else if let Some(ref kernel) = self.simdgroup_kernel {
                Ok(kernel.clone())
            } else if let Some(ref kernel) = self.kernel_handle {
                Ok(kernel.clone())
            } else {
                Err(SparseError::ComputationError(
                    "No Metal kernels available".to_string(),
                ))
            }
        } else if self.device_info.supports_simdgroups && avg_nnz_per_row > 5.0 {
            // Use SIMD group kernel for moderate sparsity
            if let Some(ref kernel) = self.simdgroup_kernel {
                Ok(kernel.clone())
            } else if let Some(ref kernel) = self.kernel_handle {
                Ok(kernel.clone())
            } else {
                Err(SparseError::ComputationError(
                    "No Metal kernels available".to_string(),
                ))
            }
        } else {
            // Use basic kernel for very sparse matrices
            if let Some(ref kernel) = self.kernel_handle {
                Ok(kernel.clone())
            } else {
                Err(SparseError::ComputationError(
                    "No Metal kernels available".to_string(),
                ))
            }
        }
    }

    /// CPU fallback implementation
    #[cfg(not(feature = "gpu"))]
    pub fn execute_spmv_cpu<T>(
        &self,
        matrix: &CsrArray<T>,
        vector: &ArrayView1<T>,
    ) -> SparseResult<Array1<T>>
    where
        T: Float + SparseElement + Debug + Copy + std::iter::Sum,
    {
        matrix.dot_vector(vector)
    }
}

impl Default for MetalSpMatVec {
    fn default() -> Self {
        Self::new().unwrap_or_else(|_| Self {
            context: None,
            kernel_handle: None,
            simdgroup_kernel: None,
            apple_silicon_kernel: None,
            neural_engine_kernel: None,
            device_info: MetalDeviceInfo::default(),
        })
    }
}

/// Metal optimization levels for sparse matrix operations
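///
/// # Example
///
/// A sketch of choosing a level from detected device capabilities (the mapping
/// below is illustrative, not a policy fixed by this crate):
///
/// ```ignore
/// let info = MetalDeviceInfo::detect();
/// let level = if info.is_apple_silicon {
///     MetalOptimizationLevel::AppleSilicon
/// } else if info.supports_simdgroups {
///     MetalOptimizationLevel::SimdGroup
/// } else {
///     MetalOptimizationLevel::Basic
/// };
/// ```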
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MetalOptimizationLevel {
    /// Basic thread-per-row implementation
    #[default]
    Basic,
    /// SIMD group optimized implementation
    SimdGroup,
    /// Apple Silicon specific optimizations
    AppleSilicon,
    /// Neural Engine preparation (future feature)
    NeuralEngine,
}

/// Metal device information for optimization
#[derive(Debug)]
pub struct MetalDeviceInfo {
    pub max_threadgroup_size: usize,
    pub shared_memory_size: usize,
    pub supports_simdgroups: bool,
    pub is_apple_silicon: bool,
    pub has_neural_engine: bool,
    pub device_name: String,
}

impl MetalDeviceInfo {
    /// Detect Metal device capabilities
    pub fn detect() -> Self {
        // In a real implementation, this would query the Metal runtime.
        // For now, return sensible defaults for Apple Silicon.
        Self {
            max_threadgroup_size: 1024,
            shared_memory_size: 32768, // 32KB
            supports_simdgroups: true,
            is_apple_silicon: Self::detect_apple_silicon(),
            has_neural_engine: Self::detect_neural_engine(),
            device_name: "Apple GPU".to_string(),
        }
    }

    fn detect_apple_silicon() -> bool {
        // Simple detection based on target architecture
        #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
        {
            true
        }
        #[cfg(not(all(target_arch = "aarch64", target_os = "macos")))]
        {
            false
        }
    }

    fn detect_neural_engine() -> bool {
        // Neural Engine is available on M1 and later
        Self::detect_apple_silicon()
    }
}

impl Default for MetalDeviceInfo {
    fn default() -> Self {
        Self::detect()
    }
}

/// Metal memory management for sparse matrices
pub struct MetalMemoryManager {
    device_info: MetalDeviceInfo,
    #[allow(dead_code)]
    allocated_buffers: Vec<String>,
}

impl MetalMemoryManager {
    /// Create a new Metal memory manager
    pub fn new() -> Self {
        Self {
            device_info: MetalDeviceInfo::detect(),
            allocated_buffers: Vec::new(),
        }
    }

    /// Allocate GPU memory for sparse matrix data with Metal-specific optimizations
    #[cfg(feature = "gpu")]
    pub fn allocate_sparse_matrix<T>(
        &mut self,
        _matrix: &CsrArray<T>,
        _device: &super::GpuDevice,
    ) -> Result<MetalMatrixBuffers<T>, super::GpuError>
    where
        T: super::GpuDataType + Copy + Float + SparseElement + Debug,
    {
        // This functionality should use GpuContext instead of GpuDevice.
        // For now, return an error indicating this needs proper implementation.
        Err(super::GpuError::BackendNotImplemented(
            super::GpuBackend::Metal,
        ))
    }

    /// Get optimal threadgroup size for the current device
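    ///
    /// # Example
    ///
    /// A small illustrative sketch (the exact sizes depend on the detected
    /// device):
    ///
    /// ```ignore
    /// let manager = MetalMemoryManager::new();
    /// let small = manager.optimal_threadgroup_size(500);
    /// let large = manager.optimal_threadgroup_size(1_000_000);
    /// assert!(small <= large);
    /// ```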
    pub fn optimal_threadgroup_size(&self, problem_size: usize) -> usize {
        let max_tg_size = self.device_info.max_threadgroup_size;

        if self.device_info.is_apple_silicon {
            // Apple Silicon prefers larger threadgroups
            if problem_size < 1000 {
                max_tg_size.min(128)
            } else {
                max_tg_size.min(256)
            }
        } else {
            // Intel/AMD GPUs prefer smaller threadgroups
            if problem_size < 1000 {
                max_tg_size.min(64)
            } else {
                max_tg_size.min(128)
            }
        }
    }

    /// Check if SIMD group operations are beneficial
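    ///
    /// # Example
    ///
    /// A sketch assuming `matrix` is a `CsrArray<f32>` built elsewhere
    /// (construction omitted). A matrix averaging around 8 stored values per
    /// row would return `true` on devices that support SIMD groups:
    ///
    /// ```ignore
    /// let manager = MetalMemoryManager::new();
    /// let use_simd = manager.should_use_simdgroups(&matrix);
    /// ```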
    pub fn should_use_simdgroups<T>(&self, matrix: &CsrArray<T>) -> bool
    where
        T: Float + SparseElement + Debug + Copy,
    {
        if !self.device_info.supports_simdgroups {
            return false;
        }

        let avg_nnz_per_row = matrix.nnz() as f64 / matrix.shape().0 as f64;

        // SIMD groups pay off once rows carry enough nonzeros to amortize the
        // reduction (roughly 5 or more stored values per row on average).
        avg_nnz_per_row >= 5.0
    }
}

impl Default for MetalMemoryManager {
    fn default() -> Self {
        Self::new()
    }
}

/// Metal storage modes for optimization
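///
/// # Example
///
/// A sketch of mapping detected hardware to a storage mode (illustrative only;
/// this crate does not prescribe the mapping):
///
/// ```ignore
/// let info = MetalDeviceInfo::detect();
/// let mode = if info.is_apple_silicon {
///     MetalStorageMode::Shared // unified memory on Apple Silicon
/// } else {
///     MetalStorageMode::Managed // discrete GPUs keep a mirrored copy
/// };
/// ```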
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetalStorageMode {
    /// Shared between CPU and GPU (Apple Silicon)
    Shared,
    /// Managed by Metal (discrete GPUs)
    Managed,
    /// Private to GPU only
    Private,
}

/// GPU memory buffers for Metal sparse matrix data
#[cfg(feature = "gpu")]
pub struct MetalMatrixBuffers<T: super::GpuDataType> {
    pub indptr: super::GpuBuffer<usize>,
    pub indices: super::GpuBuffer<usize>,
    pub data: super::GpuBuffer<T>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metal_spmv_creation() {
        let metal_spmv = MetalSpMatVec::new();
        assert!(metal_spmv.is_ok());
    }

    #[test]
    fn test_metal_optimization_levels() {
        let basic = MetalOptimizationLevel::Basic;
        let simdgroup = MetalOptimizationLevel::SimdGroup;
        let apple_silicon = MetalOptimizationLevel::AppleSilicon;
        let neural_engine = MetalOptimizationLevel::NeuralEngine;

        assert_ne!(basic, simdgroup);
        assert_ne!(simdgroup, apple_silicon);
        assert_ne!(apple_silicon, neural_engine);
        assert_eq!(
            MetalOptimizationLevel::default(),
            MetalOptimizationLevel::Basic
        );
    }

    #[test]
    fn test_metal_device_info() {
        let info = MetalDeviceInfo::detect();
        assert!(info.max_threadgroup_size > 0);
        assert!(info.shared_memory_size > 0);
        assert!(!info.device_name.is_empty());
    }

    #[test]
    fn test_apple_silicon_detection() {
        let info = MetalDeviceInfo::detect();

        // Test that detection logic runs without errors
        #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
        assert!(info.is_apple_silicon);

        #[cfg(not(all(target_arch = "aarch64", target_os = "macos")))]
        assert!(!info.is_apple_silicon);
    }

    #[test]
    fn test_metal_memory_manager() {
        let manager = MetalMemoryManager::new();
        assert_eq!(manager.allocated_buffers.len(), 0);
        assert!(manager.device_info.max_threadgroup_size > 0);

        // Test threadgroup size selection
        let tg_size_small = manager.optimal_threadgroup_size(500);
        let tg_size_large = manager.optimal_threadgroup_size(50000);
        assert!(tg_size_small > 0);
        assert!(tg_size_large > 0);
    }

    #[test]
    fn test_metal_storage_modes() {
        let modes = [
            MetalStorageMode::Shared,
            MetalStorageMode::Managed,
            MetalStorageMode::Private,
        ];

        for mode in &modes {
            match mode {
                MetalStorageMode::Shared => (),
                MetalStorageMode::Managed => (),
                MetalStorageMode::Private => (),
            }
        }
    }

    #[test]
    #[allow(clippy::const_is_empty)]
    fn test_shader_sources() {
        assert!(!METAL_SPMV_SHADER_SOURCE.is_empty());
        assert!(!METAL_APPLE_SILICON_SHADER_SOURCE.is_empty());

        // Check that shaders contain expected function names
        assert!(METAL_SPMV_SHADER_SOURCE.contains("spmv_csr_kernel"));
        assert!(METAL_SPMV_SHADER_SOURCE.contains("spmv_csr_simdgroup_kernel"));
        assert!(METAL_APPLE_SILICON_SHADER_SOURCE.contains("spmv_csr_apple_silicon_kernel"));
        assert!(METAL_APPLE_SILICON_SHADER_SOURCE.contains("spmv_csr_neural_engine_prep_kernel"));
    }
}