scirs2_sparse/simd_ops.rs

//! SIMD-accelerated operations for sparse matrices
//!
//! This module provides SIMD optimizations for general sparse matrix operations,
//! leveraging the scirs2-core SIMD infrastructure for maximum performance.

use crate::csc_array::CscArray;
use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

// Import SIMD and parallel operations from scirs2-core
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};

/// SIMD acceleration options
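///
/// # Example
///
/// A minimal sketch of customizing the defaults via struct-update syntax
/// (the field values below are illustrative, not tuned recommendations):
///
/// ```rust
/// use scirs2_sparse::simd_ops::SimdOptions;
///
/// // Keep the default chunk size but disable parallel processing.
/// let opts = SimdOptions {
///     use_parallel: false,
///     ..SimdOptions::default()
/// };
/// assert!(!opts.use_parallel);
/// ```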
#[derive(Debug, Clone)]
pub struct SimdOptions {
    /// Minimum vector length to use SIMD acceleration
    pub min_simd_size: usize,
    /// SIMD chunk size (typically 4, 8, or 16)
    pub chunk_size: usize,
    /// Use parallel processing for large operations
    pub use_parallel: bool,
    /// Minimum size to trigger parallel processing
    pub parallel_threshold: usize,
}

impl Default for SimdOptions {
    fn default() -> Self {
        // Detect platform capabilities (currently informational; not yet used for tuning)
        let _capabilities = PlatformCapabilities::detect();

        // Conservative default chunk size that works well across common SIMD widths
        let optimal_chunk_size = 8;

        Self {
            min_simd_size: optimal_chunk_size,
            chunk_size: optimal_chunk_size,
            use_parallel: true,       // Assume multi-core systems
            parallel_threshold: 8000, // Conservative threshold
        }
    }
}

/// SIMD-accelerated sparse matrix-vector multiplication for CSR matrices
///
/// This function automatically chooses between SIMD, parallel, and scalar implementations
/// based on the matrix size and data characteristics.
///
/// # Arguments
///
/// * `matrix` - The CSR matrix
/// * `x` - The input vector
/// * `options` - SIMD acceleration options
///
/// # Returns
///
/// The result vector y = A * x
///
/// # Example
///
/// ```rust
/// use scirs2_sparse::csr_array::CsrArray;
/// use scirs2_sparse::simd_ops::{simd_csr_matvec, SimdOptions};
/// use scirs2_core::ndarray::Array1;
///
/// // Create a sparse matrix
/// let rows = vec![0, 0, 1, 2, 2];
/// let cols = vec![0, 2, 1, 0, 2];
/// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
///
/// // Input vector
/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
///
/// // Compute using SIMD acceleration
/// let y = simd_csr_matvec(&matrix, &x.view(), SimdOptions::default()).unwrap();
/// ```
#[allow(dead_code)]
pub fn simd_csr_matvec<T>(
    matrix: &CsrArray<T>,
    x: &ArrayView1<T>,
    options: SimdOptions,
) -> SparseResult<Array1<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
{
    let (rows, cols) = matrix.shape();

    if x.len() != cols {
        return Err(SparseError::DimensionMismatch {
            expected: cols,
            found: x.len(),
        });
    }

    let mut y = Array1::zeros(rows);

    // Get CSR matrix data
    let (_row_indices, col_indices, values) = matrix.find();
    let row_ptr = matrix.get_indptr();

    // Choose between parallel and sequential processing based on matrix size
    if options.use_parallel && rows >= options.parallel_threshold {
        // Parallel SIMD processing implementation
        let chunk_size = rows.div_ceil(4); // Divide into 4 chunks for good load balancing
        let row_chunks: Vec<_> = (0..rows)
            .collect::<Vec<_>>()
            .chunks(chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let results: Vec<_> = parallel_map(&row_chunks, |row_chunk| {
            let mut local_y = vec![T::zero(); row_chunk.len()];

            for (local_idx, &i) in row_chunk.iter().enumerate() {
                let start = row_ptr[i];
                let end = row_ptr[i + 1];
                let row_length = end - start;

                if row_length >= options.min_simd_size {
                    // SIMD processing for longer rows
                    let mut sum = T::zero();
                    let mut j = start;

                    // Process in SIMD chunks
                    while j + options.chunk_size <= end {
                        // Gather values for SIMD processing
                        let mut values_chunk = vec![T::zero(); options.chunk_size];
                        let mut x_vals_chunk = vec![T::zero(); options.chunk_size];

                        for (idx, k) in (j..j + options.chunk_size).enumerate() {
                            values_chunk[idx] = values[k];
                            x_vals_chunk[idx] = x[col_indices[k]];
                        }

                        // Use SIMD dot product
                        let values_view = ArrayView1::from(&values_chunk);
                        let x_vals_view = ArrayView1::from(&x_vals_chunk);
                        let dot_product = T::simd_dot(&values_view, &x_vals_view);
                        sum = sum + dot_product;
                        j += options.chunk_size;
                    }

                    // Handle remaining elements with scalar operations
                    for k in j..end {
                        sum = sum + values[k] * x[col_indices[k]];
                    }

                    local_y[local_idx] = sum;
                } else {
                    // Use scalar processing for shorter rows
                    let mut sum = T::zero();
                    for k in start..end {
                        sum = sum + values[k] * x[col_indices[k]];
                    }
                    local_y[local_idx] = sum;
                }
            }

            (row_chunk.clone(), local_y)
        });

        // Merge results back into y
        for (row_chunk, local_y) in results {
            for (local_idx, &global_idx) in row_chunk.iter().enumerate() {
                y[global_idx] = local_y[local_idx];
            }
        }
    } else {
        // Sequential processing with SIMD acceleration
        for i in 0..rows {
            let start = row_ptr[i];
            let end = row_ptr[i + 1];
            let row_length = end - start;

            if row_length >= options.min_simd_size {
                // SIMD processing for longer rows
                let mut sum = T::zero();
                let mut j = start;

                // Process in SIMD-friendly chunks
                while j + options.chunk_size <= end {
                    // Prepare data for SIMD operations
                    let mut values_chunk = vec![T::zero(); options.chunk_size];
                    let mut x_vals_chunk = vec![T::zero(); options.chunk_size];

                    for (idx, k) in (j..j + options.chunk_size).enumerate() {
                        values_chunk[idx] = values[k];
                        x_vals_chunk[idx] = x[col_indices[k]];
                    }

                    // Use SIMD operations from scirs2-core
                    let values_view = ArrayView1::from(&values_chunk);
                    let x_vals_view = ArrayView1::from(&x_vals_chunk);
                    let chunk_sum = T::simd_dot(&values_view, &x_vals_view);
                    sum = sum + chunk_sum;
                    j += options.chunk_size;
                }

                // Handle remaining elements
                for k in j..end {
                    sum = sum + values[k] * x[col_indices[k]];
                }

                y[i] = sum;
            } else {
                // Scalar implementation for shorter rows
                let mut sum = T::zero();
                for k in start..end {
                    sum = sum + values[k] * x[col_indices[k]];
                }
                y[i] = sum;
            }
        }
    }

    Ok(y)
}

/// Element-wise operations that can be SIMD-accelerated
#[derive(Debug, Clone, Copy)]
pub enum ElementwiseOp {
    /// Addition
    Add,
    /// Subtraction
    Sub,
    /// Multiplication
    Mul,
    /// Division
    Div,
}

/// SIMD-accelerated element-wise operations on sparse matrices
///
/// # Arguments
///
/// * `a` - First sparse matrix
/// * `b` - Second sparse matrix
/// * `op` - Element-wise operation to perform
/// * `options` - SIMD acceleration options
///
/// # Returns
///
/// Result sparse matrix
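///
/// # Example
///
/// A minimal sketch mirroring the triplet construction used in this module's tests:
///
/// ```rust
/// use scirs2_sparse::csr_array::CsrArray;
/// use scirs2_sparse::simd_ops::{simd_sparse_elementwise, ElementwiseOp};
///
/// let rows = vec![0, 1, 2];
/// let cols = vec![0, 1, 2];
/// let data1 = vec![1.0, 2.0, 3.0];
/// let data2 = vec![4.0, 5.0, 6.0];
/// let a = CsrArray::from_triplets(&rows, &cols, &data1, (3, 3), false).unwrap();
/// let b = CsrArray::from_triplets(&rows, &cols, &data2, (3, 3), false).unwrap();
///
/// // Element-wise addition; `None` selects the default SimdOptions.
/// // The resulting diagonal is (5.0, 7.0, 9.0).
/// let sum = simd_sparse_elementwise(&a, &b, ElementwiseOp::Add, None).unwrap();
/// ```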
#[allow(dead_code)]
pub fn simd_sparse_elementwise<T, S1, S2>(
    a: &S1,
    b: &S2,
    op: ElementwiseOp,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S1: SparseArray<T>,
    S2: SparseArray<T>,
{
    if a.shape() != b.shape() {
        return Err(SparseError::DimensionMismatch {
            expected: a.shape().0 * a.shape().1,
            found: b.shape().0 * b.shape().1,
        });
    }

    let opts = options.unwrap_or_default();

    // Convert both to CSR format for efficient element-wise operations
    let a_csr = a.to_csr()?;
    let b_csr = b.to_csr()?;

    // Get matrix data
    let (_, _, a_values) = a_csr.find();
    let (_, _, b_values) = b_csr.find();

    // For sparse element-wise operations, we need to handle the union of non-zero patterns
    // This is a more complex operation that requires merging the sparsity patterns

    if a_values.len() >= opts.min_simd_size && b_values.len() >= opts.min_simd_size {
        // Use SIMD-accelerated operations for large matrices
        let result = match op {
            ElementwiseOp::Add => {
                // Try to downcast to CsrArray, otherwise use fallback
                if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
                    a_csr.as_any().downcast_ref::<CsrArray<T>>(),
                    b_csr.as_any().downcast_ref::<CsrArray<T>>(),
                ) {
                    simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, |x, y| x + y)?
                } else {
                    // Fallback to basic add operation
                    return a_csr.add(&*b_csr).and_then(|boxed| {
                        boxed
                            .as_any()
                            .downcast_ref::<CsrArray<T>>()
                            .cloned()
                            .ok_or_else(|| {
                                SparseError::ValueError(
                                    "Failed to convert result to CsrArray".to_string(),
                                )
                            })
                    });
                }
            }
            ElementwiseOp::Sub => {
                if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
                    a_csr.as_any().downcast_ref::<CsrArray<T>>(),
                    b_csr.as_any().downcast_ref::<CsrArray<T>>(),
                ) {
                    simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, |x, y| x - y)?
                } else {
                    return a_csr.sub(&*b_csr).and_then(|boxed| {
                        boxed
                            .as_any()
                            .downcast_ref::<CsrArray<T>>()
                            .cloned()
                            .ok_or_else(|| {
                                SparseError::ValueError(
                                    "Failed to convert result to CsrArray".to_string(),
                                )
                            })
                    });
                }
            }
            ElementwiseOp::Mul => {
                if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
                    a_csr.as_any().downcast_ref::<CsrArray<T>>(),
                    b_csr.as_any().downcast_ref::<CsrArray<T>>(),
                ) {
                    simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, |x, y| x * y)?
                } else {
                    return a_csr.mul(&*b_csr).and_then(|boxed| {
                        boxed
                            .as_any()
                            .downcast_ref::<CsrArray<T>>()
                            .cloned()
                            .ok_or_else(|| {
                                SparseError::ValueError(
                                    "Failed to convert result to CsrArray".to_string(),
                                )
                            })
                    });
                }
            }
            ElementwiseOp::Div => {
                if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
                    a_csr.as_any().downcast_ref::<CsrArray<T>>(),
                    b_csr.as_any().downcast_ref::<CsrArray<T>>(),
                ) {
                    simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, |x, y| x / y)?
                } else {
                    return a_csr.div(&*b_csr).and_then(|boxed| {
                        boxed
                            .as_any()
                            .downcast_ref::<CsrArray<T>>()
                            .cloned()
                            .ok_or_else(|| {
                                SparseError::ValueError(
                                    "Failed to convert result to CsrArray".to_string(),
                                )
                            })
                    });
                }
            }
        };
        Ok(result)
    } else {
        // Fall back to built-in operations for small matrices
        let result_box = match op {
            ElementwiseOp::Add => a_csr.add(&*b_csr)?,
            ElementwiseOp::Sub => a_csr.sub(&*b_csr)?,
            ElementwiseOp::Mul => a_csr.mul(&*b_csr)?,
            ElementwiseOp::Div => a_csr.div(&*b_csr)?,
        };

        // Convert the result back to CsrArray
        result_box
            .as_any()
            .downcast_ref::<CsrArray<T>>()
            .cloned()
            .ok_or_else(|| {
                SparseError::ValueError("Failed to convert result to CsrArray".to_string())
            })
    }
}

/// Applies a binary operation over the union of two sparse matrices' non-zero patterns
#[allow(dead_code)]
fn simd_sparse_binary_op<T, F>(
    a: &CsrArray<T>,
    b: &CsrArray<T>,
    options: &SimdOptions,
    op: F,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Copy,
{
    let (rows, cols) = a.shape();
    let mut result_rows = Vec::new();
    let mut result_cols = Vec::new();
    let mut result_values = Vec::new();

    // Get sparse data
    let (a_row_indices, a_col_indices, a_values) = a.find();
    let (b_row_indices, b_col_indices, b_values) = b.find();

    // Create index maps for efficient lookup
    use std::collections::HashMap;
    let mut a_map = HashMap::new();
    let mut b_map = HashMap::new();

    for (i, (&row, &col)) in a_row_indices.iter().zip(a_col_indices.iter()).enumerate() {
        a_map.insert((row, col), a_values[i]);
    }

    for (i, (&row, &col)) in b_row_indices.iter().zip(b_col_indices.iter()).enumerate() {
        b_map.insert((row, col), b_values[i]);
    }

    // Process all non-zero positions (union of both patterns)
    let mut all_positions = std::collections::BTreeSet::new();
    for &pos in a_map.keys() {
        all_positions.insert(pos);
    }
    for &pos in b_map.keys() {
        all_positions.insert(pos);
    }

    // Convert positions to a vector for chunked processing
    let positions: Vec<_> = all_positions.into_iter().collect();

    if options.use_parallel && positions.len() >= options.parallel_threshold {
        // Parallel processing over position chunks
        let chunks: Vec<_> = positions.chunks(options.chunk_size).collect();
        let results: Vec<_> = parallel_map(&chunks, |chunk| {
            let mut local_rows = Vec::new();
            let mut local_cols = Vec::new();
            let mut local_values = Vec::new();

            for &(row, col) in *chunk {
                let a_val = a_map.get(&(row, col)).copied().unwrap_or(T::zero());
                let b_val = b_map.get(&(row, col)).copied().unwrap_or(T::zero());
                let result_val = op(a_val, b_val);

                if !result_val.is_zero() {
                    local_rows.push(row);
                    local_cols.push(col);
                    local_values.push(result_val);
                }
            }

            (local_rows, local_cols, local_values)
        });

        // Merge results
        for (mut local_rows, mut local_cols, mut local_values) in results {
            result_rows.append(&mut local_rows);
            result_cols.append(&mut local_cols);
            result_values.append(&mut local_values);
        }
    } else {
        // Sequential processing
        for (row, col) in positions {
            let a_val = a_map.get(&(row, col)).copied().unwrap_or(T::zero());
            let b_val = b_map.get(&(row, col)).copied().unwrap_or(T::zero());
            let result_val = op(a_val, b_val);

            if !result_val.is_zero() {
                result_rows.push(row);
                result_cols.push(col);
                result_values.push(result_val);
            }
        }
    }

    CsrArray::from_triplets(
        &result_rows,
        &result_cols,
        &result_values,
        (rows, cols),
        false,
    )
}

/// Advanced SIMD-accelerated transpose operation
///
/// # Arguments
///
/// * `matrix` - The sparse matrix to transpose
/// * `options` - SIMD acceleration options
///
/// # Returns
///
/// Transposed matrix
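///
/// # Example
///
/// A minimal sketch; the expected contents are stated in comments rather than asserted:
///
/// ```rust
/// use scirs2_sparse::csr_array::CsrArray;
/// use scirs2_sparse::simd_ops::simd_sparse_transpose;
///
/// let rows = vec![0, 0, 1];
/// let cols = vec![0, 2, 1];
/// let data = vec![1.0, 2.0, 3.0];
/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (2, 3), false).unwrap();
///
/// // The result is the 3x2 transpose; entry (2, 0) holds 2.0.
/// let transposed = simd_sparse_transpose(&matrix, None).unwrap();
/// ```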
#[allow(dead_code)]
pub fn simd_sparse_transpose<T, S>(
    matrix: &S,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (rows, cols) = matrix.shape();
    let (row_indices, col_indices, values) = matrix.find();

    if opts.use_parallel && values.len() >= opts.parallel_threshold {
        // Parallel transpose: swap row and column indices in chunks
        let chunks: Vec<_> = (0..values.len())
            .collect::<Vec<_>>()
            .chunks(opts.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let transposed_triplets: Vec<_> = parallel_map(&chunks, |chunk| {
            let mut local_rows = Vec::new();
            let mut local_cols = Vec::new();
            let mut local_values = Vec::new();

            for &idx in chunk {
                local_rows.push(col_indices[idx]);
                local_cols.push(row_indices[idx]);
                local_values.push(values[idx]);
            }

            (local_rows, local_cols, local_values)
        });

        // Merge results
        let mut result_rows = Vec::new();
        let mut result_cols = Vec::new();
        let mut result_values = Vec::new();

        for (mut local_rows, mut local_cols, mut local_values) in transposed_triplets {
            result_rows.append(&mut local_rows);
            result_cols.append(&mut local_cols);
            result_values.append(&mut local_values);
        }

        CsrArray::from_triplets(
            &result_rows,
            &result_cols,
            &result_values,
            (cols, rows),
            false,
        )
    } else {
        // Sequential transpose
        CsrArray::from_triplets(
            col_indices.as_slice().expect("Array should be contiguous"),
            row_indices.as_slice().expect("Array should be contiguous"),
            values.as_slice().expect("Array should be contiguous"),
            (cols, rows),
            false,
        )
    }
}

/// SIMD-accelerated sparse matrix multiplication with advanced optimizations
///
/// # Arguments
///
/// * `a` - First sparse matrix
/// * `b` - Second sparse matrix
/// * `options` - SIMD acceleration options
///
/// # Returns
///
/// Result of A * B
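///
/// # Example
///
/// A minimal sketch multiplying two diagonal matrices, as in this module's tests:
///
/// ```rust
/// use scirs2_sparse::csr_array::CsrArray;
/// use scirs2_sparse::simd_ops::simd_sparse_matmul;
///
/// let rows = vec![0, 1];
/// let cols = vec![0, 1];
/// let data_a = vec![2.0, 3.0];
/// let data_b = vec![4.0, 5.0];
/// let a = CsrArray::from_triplets(&rows, &cols, &data_a, (2, 2), false).unwrap();
/// let b = CsrArray::from_triplets(&rows, &cols, &data_b, (2, 2), false).unwrap();
///
/// // The product of the two diagonals is diag(8.0, 15.0).
/// let product = simd_sparse_matmul(&a, &b, None).unwrap();
/// ```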
#[allow(dead_code)]
pub fn simd_sparse_matmul<T, S1, S2>(
    a: &S1,
    b: &S2,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S1: SparseArray<T>,
    S2: SparseArray<T>,
{
    if a.shape().1 != b.shape().0 {
        return Err(SparseError::DimensionMismatch {
            expected: a.shape().1,
            found: b.shape().0,
        });
    }

    let opts = options.unwrap_or_default();

    // Convert to CSR format for optimized multiplication
    let a_csr = a.to_csr()?;
    let b_csc = b.to_csc()?; // CSC is better for column access in matrix multiplication

    let (a_rows, _a_cols) = a_csr.shape();
    let (_b_rows, b_cols) = b_csc.shape();

    // Result matrix will be a_rows x b_cols
    let mut result_rows = Vec::new();
    let mut result_cols = Vec::new();
    let mut result_values = Vec::new();

    // Get matrix data - try to downcast to get indptr
    let a_indptr = if let Some(a_concrete) = a_csr.as_any().downcast_ref::<CsrArray<T>>() {
        a_concrete.get_indptr() // Direct method returns &Array1<usize>
    } else {
        return Err(SparseError::ValueError(
            "Matrix A must be CSR format".to_string(),
        ));
    };
    let (_, a_col_indices, a_values) = a_csr.find();

    let b_indptr = if let Some(b_concrete) = b_csc.as_any().downcast_ref::<CscArray<T>>() {
        b_concrete.get_indptr() // Direct method returns &Array1<usize>
    } else if let Some(b_concrete) = b_csc.as_any().downcast_ref::<CsrArray<T>>() {
        // Fallback: if to_csc didn't actually convert, use CSR format
        // This is less efficient but works
        b_concrete.get_indptr()
    } else {
        return Err(SparseError::ValueError(
            "Matrix B must be CSC or CSR format".to_string(),
        ));
    };
    let (_, b_row_indices, b_values) = b_csc.find();

    if opts.use_parallel && a_rows >= opts.parallel_threshold {
        // Parallel sparse matrix multiplication
        let chunks: Vec<_> = (0..a_rows)
            .collect::<Vec<_>>()
            .chunks(opts.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();
        let results: Vec<_> = parallel_map(&chunks, |row_chunk| {
            let mut local_rows = Vec::new();
            let mut local_cols = Vec::new();
            let mut local_values = Vec::new();

            for &i in row_chunk {
                let a_start = a_indptr[i];
                let a_end = a_indptr[i + 1];

                // Process each column of B
                for j in 0..b_cols {
                    let b_start = b_indptr[j];
                    let b_end = b_indptr[j + 1];

                    // Compute dot product of A[i,:] and B[:,j]
                    let mut sum = T::zero();
                    let mut a_idx = a_start;
                    let mut b_idx = b_start;

                    // Merge-based sparse dot product over the sorted index lists
                    // of A[i,:] and B[:,j]
                    while a_idx < a_end && b_idx < b_end {
                        let a_col = a_col_indices[a_idx];
                        let b_row = b_row_indices[b_idx];

                        match a_col.cmp(&b_row) {
                            std::cmp::Ordering::Equal => {
                                sum = sum + a_values[a_idx] * b_values[b_idx];
                                a_idx += 1;
                                b_idx += 1;
                            }
                            std::cmp::Ordering::Less => {
                                a_idx += 1;
                            }
                            std::cmp::Ordering::Greater => {
                                b_idx += 1;
                            }
                        }
                    }

                    if !sum.is_zero() {
                        local_rows.push(i);
                        local_cols.push(j);
                        local_values.push(sum);
                    }
                }
            }

            (local_rows, local_cols, local_values)
        });

        // Merge results
        for (mut local_rows, mut local_cols, mut local_values) in results {
            result_rows.append(&mut local_rows);
            result_cols.append(&mut local_cols);
            result_values.append(&mut local_values);
        }
    } else {
        // Sequential sparse matrix multiplication
        for i in 0..a_rows {
            let a_start = a_indptr[i];
            let a_end = a_indptr[i + 1];

            for j in 0..b_cols {
                let b_start = b_indptr[j];
                let b_end = b_indptr[j + 1];

                // Compute dot product of A[i,:] and B[:,j]
                let mut sum = T::zero();
                let mut a_idx = a_start;
                let mut b_idx = b_start;

                while a_idx < a_end && b_idx < b_end {
                    let a_col = a_col_indices[a_idx];
                    let b_row = b_row_indices[b_idx];

                    match a_col.cmp(&b_row) {
                        std::cmp::Ordering::Equal => {
                            sum = sum + a_values[a_idx] * b_values[b_idx];
                            a_idx += 1;
                            b_idx += 1;
                        }
                        std::cmp::Ordering::Less => {
                            a_idx += 1;
                        }
                        std::cmp::Ordering::Greater => {
                            b_idx += 1;
                        }
                    }
                }

                if !sum.is_zero() {
                    result_rows.push(i);
                    result_cols.push(j);
                    result_values.push(sum);
                }
            }
        }
    }

    CsrArray::from_triplets(
        &result_rows,
        &result_cols,
        &result_values,
        (a_rows, b_cols),
        false,
    )
}

/// Advanced SIMD-accelerated norm computation
///
/// Computes various matrix norms using SIMD acceleration
///
/// # Arguments
///
/// * `matrix` - The sparse matrix
/// * `norm_type` - Type of norm to compute ("fro", "1", "inf")
/// * `options` - SIMD acceleration options
///
/// # Returns
///
/// The computed norm value
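///
/// # Example
///
/// A minimal sketch computing the Frobenius norm of a small matrix:
///
/// ```rust
/// use scirs2_sparse::csr_array::CsrArray;
/// use scirs2_sparse::simd_ops::simd_sparse_norm;
///
/// let rows = vec![0, 1];
/// let cols = vec![0, 1];
/// let data = vec![3.0, 4.0];
/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();
///
/// // sqrt(3^2 + 4^2) = 5
/// let fro = simd_sparse_norm(&matrix, "fro", None).unwrap();
/// assert!((fro - 5.0_f64).abs() < 1e-12);
/// ```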
#[allow(dead_code)]
pub fn simd_sparse_norm<T, S>(
    matrix: &S,
    norm_type: &str,
    options: Option<SimdOptions>,
) -> SparseResult<T>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (_, _, values) = matrix.find();

    match norm_type {
        "fro" | "frobenius" => {
            // Frobenius norm: sqrt(sum of squares)
            if opts.use_parallel && values.len() >= opts.parallel_threshold {
                let chunks: Vec<_> = values
                    .as_slice()
                    .expect("Array should be contiguous")
                    .chunks(opts.chunk_size)
                    .collect();
                let partial_sums: Vec<T> = parallel_map(&chunks, |chunk| {
                    let chunk_view = ArrayView1::from(chunk);
                    T::simd_dot(&chunk_view, &chunk_view)
                });
                Ok(partial_sums
                    .iter()
                    .copied()
                    .fold(T::zero(), |acc, x| acc + x)
                    .sqrt())
            } else {
                let values_view = values.view();
                let sum_squares = T::simd_dot(&values_view, &values_view);
                Ok(sum_squares.sqrt())
            }
        }
        "1" => {
            // 1-norm: maximum absolute column sum
            let (_rows, cols) = matrix.shape();
            let (_row_indices, col_indices, values) = matrix.find();

            let mut column_sums = vec![T::zero(); cols];

            if opts.use_parallel && values.len() >= opts.parallel_threshold {
                let chunks: Vec<_> = (0..values.len())
                    .collect::<Vec<_>>()
                    .chunks(opts.chunk_size)
                    .map(|chunk| chunk.to_vec())
                    .collect();
                let partial_sums: Vec<Vec<T>> = parallel_map(&chunks, |chunk| {
                    let mut local_sums = vec![T::zero(); cols];
                    for &idx in chunk {
                        let col = col_indices[idx];
                        let val = values[idx].abs();
                        local_sums[col] = local_sums[col] + val;
                    }
                    local_sums
                });

                // Merge partial sums
                for partial_sum in partial_sums {
                    for j in 0..cols {
                        column_sums[j] = column_sums[j] + partial_sum[j];
                    }
                }
            } else {
                for (i, &col) in col_indices.iter().enumerate() {
                    column_sums[col] = column_sums[col] + values[i].abs();
                }
            }

            Ok(column_sums
                .iter()
                .copied()
                .fold(T::zero(), |acc, x| if x > acc { x } else { acc }))
        }
        "inf" | "infinity" => {
            // Infinity norm: maximum absolute row sum
            let (rows, _cols) = matrix.shape();
            let (row_indices, _col_indices, values) = matrix.find();

            let mut row_sums = vec![T::zero(); rows];

            if opts.use_parallel && values.len() >= opts.parallel_threshold {
                let chunks: Vec<_> = (0..values.len())
                    .collect::<Vec<_>>()
                    .chunks(opts.chunk_size)
                    .map(|chunk| chunk.to_vec())
                    .collect();
                let partial_sums: Vec<Vec<T>> = parallel_map(&chunks, |chunk| {
                    let mut local_sums = vec![T::zero(); rows];
                    for &idx in chunk {
                        let row = row_indices[idx];
                        let val = values[idx].abs();
                        local_sums[row] = local_sums[row] + val;
                    }
                    local_sums
                });

                // Merge partial sums
                for partial_sum in partial_sums {
                    for i in 0..rows {
                        row_sums[i] = row_sums[i] + partial_sum[i];
                    }
                }
            } else {
                for (i, &row) in row_indices.iter().enumerate() {
                    row_sums[row] = row_sums[row] + values[i].abs();
                }
            }

            Ok(row_sums
                .iter()
                .copied()
                .fold(T::zero(), |acc, x| if x > acc { x } else { acc }))
        }
        _ => Err(SparseError::ValueError(format!(
            "Unknown norm type: {norm_type}"
        ))),
    }
}

/// SIMD-accelerated sparse matrix scaling
///
/// Scales all non-zero elements by a scalar value using SIMD acceleration
///
/// # Arguments
///
/// * `matrix` - The sparse matrix to scale
/// * `scalar` - The scaling factor
/// * `options` - SIMD acceleration options
///
/// # Returns
///
/// Scaled matrix
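///
/// # Example
///
/// A minimal sketch scaling every stored value by 2.0:
///
/// ```rust
/// use scirs2_sparse::csr_array::CsrArray;
/// use scirs2_sparse::simd_ops::simd_sparse_scale;
///
/// let rows = vec![0, 1];
/// let cols = vec![1, 0];
/// let data = vec![1.5, -2.0];
/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();
///
/// // The stored values become 3.0 and -4.0.
/// let scaled = simd_sparse_scale(&matrix, 2.0, None).unwrap();
/// ```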
#[allow(dead_code)]
pub fn simd_sparse_scale<T, S>(
    matrix: &S,
    scalar: T,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (rows, cols) = matrix.shape();
    let (row_indices, col_indices, values) = matrix.find();

    let scaled_values = if opts.use_parallel && values.len() >= opts.parallel_threshold {
        // Parallel scaling: multiply each chunk's elements by the scalar
        let chunks: Vec<_> = values
            .as_slice()
            .expect("Array should be contiguous")
            .chunks(opts.chunk_size)
            .collect();
        let scaled_chunks: Vec<Vec<T>> = parallel_map(&chunks, |chunk: &&[T]| {
            chunk.iter().map(|&val| val * scalar).collect()
        });

        // Flatten results
        scaled_chunks.into_iter().flatten().collect()
    } else {
        // Sequential scaling
        values.iter().map(|&val| val * scalar).collect::<Vec<T>>()
    };

    CsrArray::from_triplets(
        row_indices.as_slice().expect("Array should be contiguous"),
        col_indices.as_slice().expect("Array should be contiguous"),
        &scaled_values,
        (rows, cols),
        false,
    )
}

/// Memory-efficient sparse matrix linear combination
///
/// Accumulates the coefficient-weighted non-zeros of each matrix into a single sparse result
///
/// # Arguments
///
/// * `matrices` - Vector of sparse matrices to add
/// * `coefficients` - Coefficients for each matrix (linear combination)
/// * `options` - SIMD acceleration options
///
/// # Returns
///
/// Result of coefficients\[0\] * matrices\[0\] + coefficients\[1\] * matrices\[1\] + ...
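///
/// # Example
///
/// A minimal sketch computing 2*A + 3*B for two small diagonal matrices:
///
/// ```rust
/// use scirs2_sparse::csr_array::CsrArray;
/// use scirs2_sparse::simd_ops::simd_sparse_linear_combination;
///
/// let rows = vec![0, 1];
/// let cols = vec![0, 1];
/// let data_a = vec![1.0, 2.0];
/// let data_b = vec![3.0, 4.0];
/// let a = CsrArray::from_triplets(&rows, &cols, &data_a, (2, 2), false).unwrap();
/// let b = CsrArray::from_triplets(&rows, &cols, &data_b, (2, 2), false).unwrap();
///
/// // The diagonal of the result is (2*1 + 3*3, 2*2 + 3*4) = (11.0, 16.0).
/// let combo = simd_sparse_linear_combination(&[&a, &b], &[2.0, 3.0], None).unwrap();
/// ```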
#[allow(dead_code)]
pub fn simd_sparse_linear_combination<T, S>(
    matrices: &[&S],
    coefficients: &[T],
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T> + Sync,
{
    if matrices.is_empty() {
        return Err(SparseError::ValueError("No matrices provided".to_string()));
    }

    if matrices.len() != coefficients.len() {
        return Err(SparseError::DimensionMismatch {
            expected: matrices.len(),
            found: coefficients.len(),
        });
    }

    let opts = options.unwrap_or_default();
    let (rows, cols) = matrices[0].shape();

    // Verify all matrices have the same shape
    for matrix in matrices.iter() {
        if matrix.shape() != (rows, cols) {
            return Err(SparseError::DimensionMismatch {
                expected: rows * cols,
                found: matrix.shape().0 * matrix.shape().1,
            });
        }
    }

    // Use hash map to accumulate values at each position
    use std::collections::HashMap;
    let mut accumulator = HashMap::new();

    if opts.use_parallel && matrices.len() >= 4 {
        // Parallel processing for multiple matrices
        let results: Vec<HashMap<(usize, usize), T>> = parallel_map(matrices, |matrix| {
            let mut local_accumulator = HashMap::new();
            let (row_indices, col_indices, values) = matrix.find();

            for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
                let entry = local_accumulator.entry((i, j)).or_insert(T::zero());
                *entry = *entry + values[k];
            }

            local_accumulator
        });

        // Merge results with coefficients
        for (idx, local_acc) in results.into_iter().enumerate() {
            let coeff = coefficients[idx];
            for ((i, j), val) in local_acc {
                let entry = accumulator.entry((i, j)).or_insert(T::zero());
                *entry = *entry + coeff * val;
            }
        }
    } else {
        // Sequential processing
        for (idx, matrix) in matrices.iter().enumerate() {
            let coeff = coefficients[idx];
            let (row_indices, col_indices, values) = matrix.find();

            for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
                let entry = accumulator.entry((i, j)).or_insert(T::zero());
                *entry = *entry + coeff * values[k];
            }
        }
    }

    // Convert accumulator to triplet format
    let mut result_rows = Vec::new();
    let mut result_cols = Vec::new();
    let mut result_values = Vec::new();

    for ((i, j), val) in accumulator {
        if !val.is_zero() {
            result_rows.push(i);
            result_cols.push(j);
            result_values.push(val);
        }
    }

    CsrArray::from_triplets(
        &result_rows,
        &result_cols,
        &result_values,
        (rows, cols),
        false,
    )
}

/// Convenience function for backward compatibility
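///
/// # Example
///
/// A minimal sketch; this is equivalent to calling `simd_sparse_matmul(a, b, None)`:
///
/// ```rust
/// use scirs2_sparse::csr_array::CsrArray;
/// use scirs2_sparse::simd_ops::simd_sparse_matmul_default;
///
/// let rows = vec![0, 1];
/// let cols = vec![0, 1];
/// let data = vec![2.0, 3.0];
/// let a = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();
/// let b = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();
///
/// let product = simd_sparse_matmul_default(&a, &b).unwrap();
/// ```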
#[allow(dead_code)]
pub fn simd_sparse_matmul_default<T, S1, S2>(a: &S1, b: &S2) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S1: SparseArray<T>,
    S2: SparseArray<T>,
{
    simd_sparse_matmul(a, b, None)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_csr_matvec() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = simd_csr_matvec(&matrix, &x.view(), SimdOptions::default()).unwrap();

        // Expected: [1*1 + 2*3, 3*2, 4*1 + 5*3] = [7, 6, 19]
        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 7.0);
        assert_relative_eq!(y[1], 6.0);
        assert_relative_eq!(y[2], 19.0);
    }

    #[test]
    fn test_simd_sparse_elementwise() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data1 = vec![1.0, 2.0, 3.0];
        let data2 = vec![4.0, 5.0, 6.0];

        let a = CsrArray::from_triplets(&rows, &cols, &data1, (3, 3), false).unwrap();
        let b = CsrArray::from_triplets(&rows, &cols, &data2, (3, 3), false).unwrap();

        let result = simd_sparse_elementwise(&a, &b, ElementwiseOp::Add, None).unwrap();

        // Check diagonal elements: 1+4=5, 2+5=7, 3+6=9
        assert_relative_eq!(result.get(0, 0), 5.0);
        assert_relative_eq!(result.get(1, 1), 7.0);
        assert_relative_eq!(result.get(2, 2), 9.0);
    }

    #[test]
    fn test_simd_sparse_matmul() {
        // Create two 2x2 matrices
        let rows = vec![0, 1];
        let cols = vec![0, 1];
        let data1 = vec![2.0, 3.0];
        let data2 = vec![4.0, 5.0];

        let a = CsrArray::from_triplets(&rows, &cols, &data1, (2, 2), false).unwrap();
        let b = CsrArray::from_triplets(&rows, &cols, &data2, (2, 2), false).unwrap();

        let result = simd_sparse_matmul_default(&a, &b).unwrap();

        // For diagonal matrices: [2*4, 3*5] = [8, 15]
        assert_relative_eq!(result.get(0, 0), 8.0);
        assert_relative_eq!(result.get(1, 1), 15.0);
        assert_relative_eq!(result.get(0, 1), 0.0);
        assert_relative_eq!(result.get(1, 0), 0.0);
    }
}