scirs2_sparse/gpu/metal.rs

//! Metal backend for sparse matrix GPU operations on Apple platforms
//!
//! This module provides Metal-specific implementations for sparse matrix operations
//! optimized for Apple Silicon and Intel Macs with discrete GPUs.
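//!
//! A minimal, hedged usage sketch of the CPU-side helpers (the module path and the
//! reported capabilities are assumptions that depend on the crate layout and the host
//! machine, so the example is not run as a doctest):
//!
//! ```ignore
//! use scirs2_sparse::gpu::metal::{MetalDeviceInfo, MetalMemoryManager};
//!
//! // Query (or default) the Metal device capabilities.
//! let info = MetalDeviceInfo::detect();
//! println!("Apple Silicon: {}", info.is_apple_silicon);
//!
//! // Pick a threadgroup size for a problem with roughly 50_000 rows.
//! let manager = MetalMemoryManager::new();
//! assert!(manager.optimal_threadgroup_size(50_000) > 0);
//! ```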

use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;

#[cfg(feature = "gpu")]
use crate::gpu_kernel_execution::{GpuKernelConfig, MemoryStrategy};

#[cfg(feature = "gpu")]
pub use scirs2_core::gpu::{GpuBackend, GpuBuffer, GpuContext, GpuDataType, GpuKernelHandle};

#[cfg(feature = "gpu")]
pub use scirs2_core::GpuError;

/// Metal shader source code for sparse matrix-vector multiplication
pub const METAL_SPMV_SHADER_SOURCE: &str = r#"
#include <metal_stdlib>
using namespace metal;

kernel void spmv_csr_kernel(
    device const int* indptr [[buffer(0)]],
    device const int* indices [[buffer(1)]],
    device const float* data [[buffer(2)]],
    device const float* x [[buffer(3)]],
    device float* y [[buffer(4)]],
    constant int& rows [[buffer(5)]],
    uint gid [[thread_position_in_grid]]
) {
    if (gid >= uint(rows)) return;

    float sum = 0.0f;
    int start = indptr[gid];
    int end = indptr[gid + 1];

    for (int j = start; j < end; j++) {
        sum += data[j] * x[indices[j]];
    }

    y[gid] = sum;
}

kernel void spmv_csr_simdgroup_kernel(
    device const int* indptr [[buffer(0)]],
    device const int* indices [[buffer(1)]],
    device const float* data [[buffer(2)]],
    device const float* x [[buffer(3)]],
    device float* y [[buffer(4)]],
    constant int& rows [[buffer(5)]],
    uint gid [[thread_position_in_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]
) {
    // One 32-lane SIMD group cooperates on a single row, so the host must
    // dispatch 32 threads per row for this kernel.
    uint row = gid / 32;
    if (row >= uint(rows)) return;

    int start = indptr[row];
    int end = indptr[row + 1];
    float sum = 0.0f;

    // Each lane accumulates a strided slice of the row for better memory
    // coalescing on Apple Silicon.
    for (int j = start + int(simd_lane_id); j < end; j += 32) {
        sum += data[j] * x[indices[j]];
    }

    // Reduce the per-lane partial sums across the SIMD group.
    sum = simd_sum(sum);

    if (simd_lane_id == 0) {
        y[row] = sum;
    }
}
78"#;

/// Metal shader for Apple Silicon optimized operations
pub const METAL_APPLE_SILICON_SHADER_SOURCE: &str = r#"
#include <metal_stdlib>
using namespace metal;

kernel void spmv_csr_apple_silicon_kernel(
    device const int* indptr [[buffer(0)]],
    device const int* indices [[buffer(1)]],
    device const float* data [[buffer(2)]],
    device const float* x [[buffer(3)]],
    device float* y [[buffer(4)]],
    constant int& rows [[buffer(5)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    threadgroup float* shared_data [[threadgroup(0)]]
) {
    if (gid >= uint(rows)) return;

    int start = indptr[gid];
    int end = indptr[gid + 1];

    // Use unified memory architecture efficiently
    shared_data[lid] = 0.0f;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int j = start; j < end; j++) {
        shared_data[lid] += data[j] * x[indices[j]];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    y[gid] = shared_data[lid];
}

kernel void spmv_csr_neural_engine_prep_kernel(
    device const int* indptr [[buffer(0)]],
    device const int* indices [[buffer(1)]],
    device const float* data [[buffer(2)]],
    device const float* x [[buffer(3)]],
    device float* y [[buffer(4)]],
    constant int& rows [[buffer(5)]],
    uint gid [[thread_position_in_grid]]
) {
    // Prepare data layout for potential Neural Engine acceleration
    if (gid >= uint(rows)) return;

    int start = indptr[gid];
    int end = indptr[gid + 1];
    float sum = 0.0f;

    // Use float4 for better throughput on Apple Silicon
    int j = start;
    for (; j + 3 < end; j += 4) {
        float4 data_vec = float4(data[j], data[j+1], data[j+2], data[j+3]);
        float4 x_vec = float4(
            x[indices[j]],
            x[indices[j+1]],
            x[indices[j+2]],
            x[indices[j+3]]
        );
        float4 prod = data_vec * x_vec;
        sum += prod.x + prod.y + prod.z + prod.w;
    }

    // Handle remaining elements
    for (; j < end; j++) {
        sum += data[j] * x[indices[j]];
    }

    y[gid] = sum;
}
"#;

/// Metal sparse matrix operations
pub struct MetalSpMatVec {
    context: Option<scirs2_core::gpu::GpuContext>,
    kernel_handle: Option<scirs2_core::gpu::GpuKernelHandle>,
    simdgroup_kernel: Option<scirs2_core::gpu::GpuKernelHandle>,
    apple_silicon_kernel: Option<scirs2_core::gpu::GpuKernelHandle>,
    neural_engine_kernel: Option<scirs2_core::gpu::GpuKernelHandle>,
    device_info: MetalDeviceInfo,
}

impl MetalSpMatVec {
    /// Create a new Metal sparse matrix-vector multiplication handler
    pub fn new() -> SparseResult<Self> {
        // Try to create Metal context
        #[cfg(feature = "gpu")]
        let context = match scirs2_core::gpu::GpuContext::new(scirs2_core::gpu::GpuBackend::Metal) {
            Ok(ctx) => Some(ctx),
            Err(_) => None, // Metal not available, will use CPU fallback
        };
        #[cfg(not(feature = "gpu"))]
        let context = None;

        let mut handler = Self {
            context,
            kernel_handle: None,
            simdgroup_kernel: None,
            apple_silicon_kernel: None,
            neural_engine_kernel: None,
            device_info: MetalDeviceInfo::detect(),
        };

        // Compile kernels if context is available
        #[cfg(feature = "gpu")]
        if handler.context.is_some() {
            let _ = handler.compile_kernels();
        }

        Ok(handler)
    }

    /// Compile Metal shaders for sparse matrix operations
    #[cfg(feature = "gpu")]
    pub fn compile_kernels(&mut self) -> Result<(), scirs2_core::gpu::GpuError> {
        if let Some(ref context) = self.context {
            // Compile kernels using the context. The basic and SIMD-group
            // kernels live in the same shader source, so both handles are
            // compiled from METAL_SPMV_SHADER_SOURCE.
            self.kernel_handle =
                context.execute(|compiler| compiler.compile(METAL_SPMV_SHADER_SOURCE).ok());

            self.simdgroup_kernel =
                context.execute(|compiler| compiler.compile(METAL_SPMV_SHADER_SOURCE).ok());

            // Apple Silicon specific optimizations
            if self.device_info.is_apple_silicon {
                self.apple_silicon_kernel = context
                    .execute(|compiler| compiler.compile(METAL_APPLE_SILICON_SHADER_SOURCE).ok());

                // Neural Engine kernel would compile the same shader separately
                if self.device_info.has_neural_engine {
                    self.neural_engine_kernel = context.execute(|compiler| {
                        compiler.compile(METAL_APPLE_SILICON_SHADER_SOURCE).ok()
                    });
                }
            }

            if self.kernel_handle.is_some() {
                Ok(())
            } else {
                Err(scirs2_core::gpu::GpuError::KernelCompilationError(
                    "Failed to compile Metal kernels".to_string(),
                ))
            }
        } else {
            Err(scirs2_core::gpu::GpuError::BackendNotAvailable(
                "Metal".to_string(),
            ))
        }
    }

    /// Execute Metal sparse matrix-vector multiplication
    #[cfg(feature = "gpu")]
    pub fn execute_spmv<T>(
        &self,
        matrix: &CsrArray<T>,
        vector: &ArrayView1<T>,
        _device: &super::GpuDevice,
    ) -> SparseResult<Array1<T>>
    where
        T: Float + SparseElement + Debug + Copy + scirs2_core::gpu::GpuDataType,
    {
        let (rows, cols) = matrix.shape();
        if cols != vector.len() {
            return Err(SparseError::DimensionMismatch {
                expected: cols,
                found: vector.len(),
            });
        }

        if let Some(ref context) = self.context {
            // Select the best kernel based on device capabilities
            let kernel = if self.device_info.is_apple_silicon {
                self.apple_silicon_kernel
                    .as_ref()
                    .or(self.simdgroup_kernel.as_ref())
                    .or(self.kernel_handle.as_ref())
            } else {
                self.simdgroup_kernel
                    .as_ref()
                    .or(self.kernel_handle.as_ref())
            };

            if let Some(kernel) = kernel {
                // Upload data to GPU
                let indptr_buffer =
                    context.create_buffer_from_slice(matrix.get_indptr().as_slice().unwrap());
                let indices_buffer =
                    context.create_buffer_from_slice(matrix.get_indices().as_slice().unwrap());
                let data_buffer =
                    context.create_buffer_from_slice(matrix.get_data().as_slice().unwrap());
                let vector_buffer = context.create_buffer_from_slice(vector.as_slice().unwrap());
                let result_buffer = context.create_buffer::<T>(rows);

                // Set kernel parameters (the shader binds the row count as `rows`)
                kernel.set_buffer("indptr", &indptr_buffer);
                kernel.set_buffer("indices", &indices_buffer);
                kernel.set_buffer("data", &data_buffer);
                kernel.set_buffer("x", &vector_buffer);
                kernel.set_buffer("y", &result_buffer);
                kernel.set_u32("rows", rows as u32);

                // Configure threadgroup size for Metal (one thread per row)
                let threadgroup_size = self.device_info.max_threadgroup_size.min(256);
                let grid_size = ((rows + threadgroup_size - 1) / threadgroup_size, 1, 1);
                let block_size = (threadgroup_size, 1, 1);
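
                // Example: rows = 1000 and threadgroup_size = 256 give
                // grid_size = (1000 + 255) / 256 = 4 threadgroups (1024 threads);
                // the kernel's `gid >= rows` check ignores the 24 extra threads.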

                // Execute kernel
                let args = vec![scirs2_core::gpu::DynamicKernelArg::U32(rows as u32)];

                context
                    .launch_kernel("spmv_csr_kernel", grid_size, block_size, &args)
                    .map_err(|e| {
                        SparseError::ComputationError(format!(
                            "Metal kernel execution failed: {:?}",
                            e
                        ))
                    })?;

                // Read result back
                let mut result_vec = vec![T::sparse_zero(); rows];
                result_buffer.copy_to_host(&mut result_vec).map_err(|e| {
                    SparseError::ComputationError(format!(
                        "Failed to copy result from GPU: {:?}",
                        e
                    ))
                })?;
                Ok(Array1::from_vec(result_vec))
            } else {
                Err(SparseError::ComputationError(
                    "Metal kernel not compiled".to_string(),
                ))
            }
        } else {
            // Fallback to CPU implementation
            matrix.dot_vector(vector)
        }
    }

    /// Execute optimized Metal sparse matrix-vector multiplication
    #[cfg(feature = "gpu")]
    pub fn execute_optimized_spmv<T>(
        &self,
        matrix: &CsrArray<T>,
        vector: &ArrayView1<T>,
        device: &super::GpuDevice,
        optimization_level: MetalOptimizationLevel,
    ) -> SparseResult<Array1<T>>
    where
        T: Float + SparseElement + Debug + Copy + super::GpuDataType,
    {
        let (_rows, cols) = matrix.shape();
        if cols != vector.len() {
            return Err(SparseError::DimensionMismatch {
                expected: cols,
                found: vector.len(),
            });
        }

        // Choose kernel based on optimization level and device capabilities
        let kernel = match optimization_level {
            MetalOptimizationLevel::Basic => &self.kernel_handle,
            MetalOptimizationLevel::SimdGroup => &self.simdgroup_kernel,
            MetalOptimizationLevel::AppleSilicon => &self.apple_silicon_kernel,
            MetalOptimizationLevel::NeuralEngine => &self.neural_engine_kernel,
        };

        if let Some(ref k) = kernel {
            self.execute_kernel_with_optimization(matrix, vector, device, k, optimization_level)
        } else {
            // Fallback to basic kernel if specific optimization not available
            if let Some(ref basic_kernel) = self.kernel_handle {
                self.execute_kernel_with_optimization(
                    matrix,
                    vector,
                    device,
                    basic_kernel,
                    MetalOptimizationLevel::Basic,
                )
            } else {
                Err(SparseError::ComputationError(
                    "No Metal kernels available".to_string(),
                ))
            }
        }
    }

    #[cfg(feature = "gpu")]
    fn execute_kernel_with_optimization<T>(
        &self,
        matrix: &CsrArray<T>,
        vector: &ArrayView1<T>,
        _device: &super::GpuDevice,
        _kernel: &super::GpuKernelHandle,
        optimization_level: MetalOptimizationLevel,
    ) -> SparseResult<Array1<T>>
    where
        T: Float + SparseElement + Debug + Copy + super::GpuDataType,
    {
        let (rows, _) = matrix.shape();

        if let Some(ref context) = self.context {
            // Upload data to GPU using context
            let indptr_gpu =
                context.create_buffer_from_slice(matrix.get_indptr().as_slice().unwrap());
            let indices_gpu =
                context.create_buffer_from_slice(matrix.get_indices().as_slice().unwrap());
            let data_gpu = context.create_buffer_from_slice(matrix.get_data().as_slice().unwrap());
            let vector_gpu = context.create_buffer_from_slice(vector.as_slice().unwrap());
            let result_gpu = context.create_buffer::<T>(rows);

            // Configure launch parameters based on optimization level
            let (threadgroup_size, _uses_shared_memory) = match optimization_level {
                MetalOptimizationLevel::Basic => {
                    (self.device_info.max_threadgroup_size.min(64), false)
                }
                MetalOptimizationLevel::SimdGroup => {
                    (self.device_info.max_threadgroup_size.min(128), false)
                }
                MetalOptimizationLevel::AppleSilicon => {
                    (self.device_info.max_threadgroup_size.min(256), true)
                }
                MetalOptimizationLevel::NeuralEngine => {
                    // Optimize for Neural Engine pipeline
                    (self.device_info.max_threadgroup_size.min(128), false)
                }
            };

            // The SIMD-group kernel assigns one 32-lane SIMD group per row and
            // therefore needs 32 threads per row; the other kernels use one
            // thread per row.
            let total_threads = match optimization_level {
                MetalOptimizationLevel::SimdGroup => rows * 32,
                _ => rows,
            };
            let grid_size = (total_threads + threadgroup_size - 1) / threadgroup_size;
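
            // Example: rows = 10_000 with a one-thread-per-row kernel and
            // threadgroup_size = 128 gives grid_size = 79 threadgroups.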

            // Launch kernel using context
            let args = vec![scirs2_core::gpu::DynamicKernelArg::U32(rows as u32)];

            // Use appropriate kernel based on optimization level
            let kernel_name = match optimization_level {
                MetalOptimizationLevel::Basic => "spmv_csr_kernel",
                MetalOptimizationLevel::SimdGroup => "spmv_csr_simdgroup_kernel",
                MetalOptimizationLevel::AppleSilicon => "spmv_csr_apple_silicon_kernel",
                MetalOptimizationLevel::NeuralEngine => "spmv_csr_neural_engine_prep_kernel",
            };

            context
                .launch_kernel(
                    kernel_name,
                    (grid_size, 1, 1),
                    (threadgroup_size, 1, 1),
                    &args,
                )
                .map_err(|e| {
                    SparseError::ComputationError(format!("Metal kernel execution failed: {:?}", e))
                })?;

            // Download result
            let mut result_vec = vec![T::sparse_zero(); rows];
            result_gpu.copy_to_host(&mut result_vec).map_err(|e| {
                SparseError::ComputationError(format!("Failed to copy result from GPU: {:?}", e))
            })?;
            Ok(Array1::from_vec(result_vec))
        } else {
            // Fallback to CPU implementation
            matrix.dot_vector(vector)
        }
    }

    /// Select optimal kernel based on device and matrix characteristics
    #[cfg(feature = "gpu")]
    #[allow(dead_code)]
    fn select_optimal_kernel<T>(
        &self,
        rows: usize,
        matrix: &CsrArray<T>,
    ) -> SparseResult<super::GpuKernelHandle>
    where
        T: Float + SparseElement + Debug + Copy,
    {
        let avg_nnz_per_row = matrix.get_data().len() as f64 / rows as f64;
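        // Example: a 10_000-row matrix with 200_000 stored values averages
        // 20 nonzeros per row and prefers the Apple Silicon or SIMD-group
        // kernels when they are available.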

        // Select kernel based on device capabilities and matrix characteristics
        if self.device_info.is_apple_silicon && avg_nnz_per_row > 16.0 {
            // Use Apple Silicon optimized kernel for dense-ish matrices
            if let Some(ref kernel) = self.apple_silicon_kernel {
                Ok(kernel.clone())
            } else if let Some(ref kernel) = self.simdgroup_kernel {
                Ok(kernel.clone())
            } else if let Some(ref kernel) = self.kernel_handle {
                Ok(kernel.clone())
            } else {
                Err(SparseError::ComputationError(
                    "No Metal kernels available".to_string(),
                ))
            }
        } else if self.device_info.supports_simdgroups && avg_nnz_per_row > 5.0 {
            // Use SIMD group kernel for moderate sparsity
            if let Some(ref kernel) = self.simdgroup_kernel {
                Ok(kernel.clone())
            } else if let Some(ref kernel) = self.kernel_handle {
                Ok(kernel.clone())
            } else {
                Err(SparseError::ComputationError(
                    "No Metal kernels available".to_string(),
                ))
            }
        } else {
            // Use basic kernel for very sparse matrices
            if let Some(ref kernel) = self.kernel_handle {
                Ok(kernel.clone())
            } else {
                Err(SparseError::ComputationError(
                    "No Metal kernels available".to_string(),
                ))
            }
        }
    }

    /// CPU fallback implementation
    #[cfg(not(feature = "gpu"))]
    pub fn execute_spmv_cpu<T>(
        &self,
        matrix: &CsrArray<T>,
        vector: &ArrayView1<T>,
    ) -> SparseResult<Array1<T>>
    where
        T: Float + SparseElement + Debug + Copy + std::iter::Sum,
    {
        matrix.dot_vector(vector)
    }
}

impl Default for MetalSpMatVec {
    fn default() -> Self {
        Self::new().unwrap_or_else(|_| Self {
            context: None,
            kernel_handle: None,
            simdgroup_kernel: None,
            apple_silicon_kernel: None,
            neural_engine_kernel: None,
            device_info: MetalDeviceInfo::default(),
        })
    }
}

/// Metal optimization levels for sparse matrix operations
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetalOptimizationLevel {
    /// Basic thread-per-row implementation
    Basic,
    /// SIMD group optimized implementation
    SimdGroup,
    /// Apple Silicon specific optimizations
    AppleSilicon,
    /// Neural Engine preparation (future feature)
    NeuralEngine,
}

impl Default for MetalOptimizationLevel {
    fn default() -> Self {
        Self::Basic
    }
}

/// Metal device information for optimization
#[derive(Debug)]
pub struct MetalDeviceInfo {
    pub max_threadgroup_size: usize,
    pub shared_memory_size: usize,
    pub supports_simdgroups: bool,
    pub is_apple_silicon: bool,
    pub has_neural_engine: bool,
    pub device_name: String,
}

impl MetalDeviceInfo {
    /// Detect Metal device capabilities
    pub fn detect() -> Self {
        // In a real implementation, this would query the Metal runtime.
        // For now, return sensible defaults for Apple Silicon.
        Self {
            max_threadgroup_size: 1024,
            shared_memory_size: 32768, // 32KB
            supports_simdgroups: true,
            is_apple_silicon: Self::detect_apple_silicon(),
            has_neural_engine: Self::detect_neural_engine(),
            device_name: "Apple GPU".to_string(),
        }
    }

    fn detect_apple_silicon() -> bool {
        // Simple detection based on target architecture
        #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
        {
            true
        }
        #[cfg(not(all(target_arch = "aarch64", target_os = "macos")))]
        {
            false
        }
    }

    fn detect_neural_engine() -> bool {
        // The Neural Engine is available on Apple Silicon (M1 and later)
        Self::detect_apple_silicon()
    }
}

impl Default for MetalDeviceInfo {
    fn default() -> Self {
        Self::detect()
    }
}

/// Metal memory management for sparse matrices
pub struct MetalMemoryManager {
    device_info: MetalDeviceInfo,
    #[allow(dead_code)]
    allocated_buffers: Vec<String>,
}

impl MetalMemoryManager {
    /// Create a new Metal memory manager
    pub fn new() -> Self {
        Self {
            device_info: MetalDeviceInfo::detect(),
            allocated_buffers: Vec::new(),
        }
    }

    /// Allocate GPU memory for sparse matrix data with Metal-specific optimizations
    #[cfg(feature = "gpu")]
    pub fn allocate_sparse_matrix<T>(
        &mut self,
        _matrix: &CsrArray<T>,
        _device: &super::GpuDevice,
    ) -> Result<MetalMatrixBuffers<T>, super::GpuError>
    where
        T: super::GpuDataType + Copy + Float + SparseElement + Debug,
    {
        // This functionality should use GpuContext instead of GpuDevice.
        // For now, return an error indicating this needs proper implementation.
        Err(super::GpuError::BackendNotImplemented(
            super::GpuBackend::Metal,
        ))
    }

    /// Get optimal threadgroup size for the current device
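    ///
    /// A hedged sketch of the heuristic (the module path is assumed from the
    /// file location, so the example is not run as a doctest):
    ///
    /// ```ignore
    /// use scirs2_sparse::gpu::metal::MetalMemoryManager;
    ///
    /// let manager = MetalMemoryManager::new();
    /// // Larger problems get at least as large a threadgroup as small ones.
    /// assert!(manager.optimal_threadgroup_size(500) <= manager.optimal_threadgroup_size(50_000));
    /// ```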
    pub fn optimal_threadgroup_size(&self, problem_size: usize) -> usize {
        let max_tg_size = self.device_info.max_threadgroup_size;

        if self.device_info.is_apple_silicon {
            // Apple Silicon prefers larger threadgroups
            if problem_size < 1000 {
                max_tg_size.min(128)
            } else {
                max_tg_size.min(256)
            }
        } else {
            // Intel/AMD GPUs prefer smaller threadgroups
            if problem_size < 1000 {
                max_tg_size.min(64)
            } else {
                max_tg_size.min(128)
            }
        }
    }

    /// Check if SIMD group operations are beneficial
    pub fn should_use_simdgroups<T>(&self, matrix: &CsrArray<T>) -> bool
    where
        T: Float + SparseElement + Debug + Copy,
    {
        if !self.device_info.supports_simdgroups {
            return false;
        }

        let avg_nnz_per_row = matrix.nnz() as f64 / matrix.shape().0 as f64;

        // SIMD groups pay off for rows with a moderate to high number of nonzeros
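        // Example: a 1_000 x 1_000 matrix with 10_000 stored values averages
        // 10 nonzeros per row and qualifies; 2_000 stored values (2 per row)
        // would not.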
        avg_nnz_per_row >= 5.0
    }
}

impl Default for MetalMemoryManager {
    fn default() -> Self {
        Self::new()
    }
}

/// Metal storage modes for optimization
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MetalStorageMode {
    /// Shared between CPU and GPU (Apple Silicon)
    Shared,
    /// Managed by Metal (discrete GPUs)
    Managed,
    /// Private to GPU only
    Private,
}

/// GPU memory buffers for Metal sparse matrix data
#[cfg(feature = "gpu")]
pub struct MetalMatrixBuffers<T: super::GpuDataType> {
    pub indptr: super::GpuBuffer<usize>,
    pub indices: super::GpuBuffer<usize>,
    pub data: super::GpuBuffer<T>,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metal_spmv_creation() {
        let metal_spmv = MetalSpMatVec::new();
        assert!(metal_spmv.is_ok());
    }

    #[test]
    fn test_metal_optimization_levels() {
        let basic = MetalOptimizationLevel::Basic;
        let simdgroup = MetalOptimizationLevel::SimdGroup;
        let apple_silicon = MetalOptimizationLevel::AppleSilicon;
        let neural_engine = MetalOptimizationLevel::NeuralEngine;

        assert_ne!(basic, simdgroup);
        assert_ne!(simdgroup, apple_silicon);
        assert_ne!(apple_silicon, neural_engine);
        assert_eq!(
            MetalOptimizationLevel::default(),
            MetalOptimizationLevel::Basic
        );
    }

    #[test]
    fn test_metal_device_info() {
        let info = MetalDeviceInfo::detect();
        assert!(info.max_threadgroup_size > 0);
        assert!(info.shared_memory_size > 0);
        assert!(!info.device_name.is_empty());
    }

    #[test]
    fn test_apple_silicon_detection() {
        let info = MetalDeviceInfo::detect();

        // Test that detection logic runs without errors
        #[cfg(all(target_arch = "aarch64", target_os = "macos"))]
        assert!(info.is_apple_silicon);

        #[cfg(not(all(target_arch = "aarch64", target_os = "macos")))]
        assert!(!info.is_apple_silicon);
    }

    #[test]
    fn test_metal_memory_manager() {
        let manager = MetalMemoryManager::new();
        assert_eq!(manager.allocated_buffers.len(), 0);
        assert!(manager.device_info.max_threadgroup_size > 0);

        // Test threadgroup size selection
        let tg_size_small = manager.optimal_threadgroup_size(500);
        let tg_size_large = manager.optimal_threadgroup_size(50000);
        assert!(tg_size_small > 0);
        assert!(tg_size_large > 0);
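
        // Under the current heuristic the larger problem never gets a smaller
        // threadgroup than the small one.
        assert!(tg_size_small <= tg_size_large);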
    }

    #[test]
    fn test_metal_storage_modes() {
        let modes = [
            MetalStorageMode::Shared,
            MetalStorageMode::Managed,
            MetalStorageMode::Private,
        ];

        for mode in &modes {
            match mode {
                MetalStorageMode::Shared => (),
                MetalStorageMode::Managed => (),
                MetalStorageMode::Private => (),
            }
        }
    }

    #[test]
    #[allow(clippy::const_is_empty)]
    fn test_shader_sources() {
        assert!(!METAL_SPMV_SHADER_SOURCE.is_empty());
        assert!(!METAL_APPLE_SILICON_SHADER_SOURCE.is_empty());

        // Check that shaders contain expected function names
        assert!(METAL_SPMV_SHADER_SOURCE.contains("spmv_csr_kernel"));
        assert!(METAL_SPMV_SHADER_SOURCE.contains("spmv_csr_simdgroup_kernel"));
        assert!(METAL_APPLE_SILICON_SHADER_SOURCE.contains("spmv_csr_apple_silicon_kernel"));
        assert!(METAL_APPLE_SILICON_SHADER_SOURCE.contains("spmv_csr_neural_engine_prep_kernel"));
    }
}