use crate::csc_array::CscArray;
use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};

/// Tuning options for the SIMD-accelerated sparse kernels in this module.
#[derive(Debug, Clone)]
pub struct SimdOptions {
    /// Minimum number of elements before the SIMD path is taken.
    pub min_simd_size: usize,
    /// Number of elements processed per SIMD chunk.
    pub chunk_size: usize,
    /// Whether to split work across threads.
    pub use_parallel: bool,
    /// Minimum problem size before the parallel path is taken.
    pub parallel_threshold: usize,
}

impl Default for SimdOptions {
    fn default() -> Self {
        // Platform capabilities are detected but currently unused; the values
        // below are conservative defaults.
        let _capabilities = PlatformCapabilities::detect();

        let optimal_chunk_size = 8;
        Self {
            min_simd_size: optimal_chunk_size,
            chunk_size: optimal_chunk_size,
            use_parallel: true,
            parallel_threshold: 8000,
        }
    }
}

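// Usage sketch (hypothetical caller code, not part of this module's API): the
// defaults can be overridden per call site with struct-update syntax, e.g.
//     let opts = SimdOptions { chunk_size: 16, ..SimdOptions::default() };
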
/// Sparse matrix-vector product `y = A * x` for a CSR matrix.
///
/// Rows with at least `options.min_simd_size` nonzeros are reduced with SIMD
/// dot products over gathered chunks; large matrices are processed row-wise
/// in parallel when `options.use_parallel` is set.
#[allow(dead_code)]
pub fn simd_csr_matvec<T>(
    matrix: &CsrArray<T>,
    x: &ArrayView1<T>,
    options: SimdOptions,
) -> SparseResult<Array1<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
{
    let (rows, cols) = matrix.shape();

    if x.len() != cols {
        return Err(SparseError::DimensionMismatch {
            expected: cols,
            found: x.len(),
        });
    }

    let mut y = Array1::zeros(rows);

    let (_row_indices, col_indices, values) = matrix.find();
    let row_ptr = matrix.get_indptr();

    if options.use_parallel && rows >= options.parallel_threshold {
        // Split the rows into four roughly equal chunks and process them in parallel.
        let chunk_size = rows.div_ceil(4);
        let row_chunks: Vec<_> = (0..rows)
            .collect::<Vec<_>>()
            .chunks(chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let results: Vec<_> = parallel_map(&row_chunks, |row_chunk| {
            let mut local_y = vec![T::zero(); row_chunk.len()];

            for (local_idx, &i) in row_chunk.iter().enumerate() {
                let start = row_ptr[i];
                let end = row_ptr[i + 1];
                let row_length = end - start;

                if row_length >= options.min_simd_size {
                    // Gather row values and matching x entries into contiguous
                    // buffers, then reduce each chunk with a SIMD dot product.
                    let mut sum = T::zero();
                    let mut j = start;

                    while j + options.chunk_size <= end {
                        let mut values_chunk = vec![T::zero(); options.chunk_size];
                        let mut x_vals_chunk = vec![T::zero(); options.chunk_size];

                        for (idx, k) in (j..j + options.chunk_size).enumerate() {
                            values_chunk[idx] = values[k];
                            x_vals_chunk[idx] = x[col_indices[k]];
                        }

                        let values_view = ArrayView1::from(&values_chunk);
                        let x_vals_view = ArrayView1::from(&x_vals_chunk);
                        let dot_product = T::simd_dot(&values_view, &x_vals_view);
                        sum = sum + dot_product;
                        j += options.chunk_size;
                    }

                    // Scalar tail for elements that do not fill a full chunk.
                    for k in j..end {
                        sum = sum + values[k] * x[col_indices[k]];
                    }

                    local_y[local_idx] = sum;
                } else {
                    // Short rows: plain scalar accumulation.
                    let mut sum = T::zero();
                    for k in start..end {
                        sum = sum + values[k] * x[col_indices[k]];
                    }
                    local_y[local_idx] = sum;
                }
            }

            (row_chunk.clone(), local_y)
        });

        // Scatter the per-chunk results back into the output vector.
        for (row_chunk, local_y) in results {
            for (local_idx, &global_idx) in row_chunk.iter().enumerate() {
                y[global_idx] = local_y[local_idx];
            }
        }
    } else {
        for i in 0..rows {
            let start = row_ptr[i];
            let end = row_ptr[i + 1];
            let row_length = end - start;

            if row_length >= options.min_simd_size {
                let mut sum = T::zero();
                let mut j = start;

                while j + options.chunk_size <= end {
                    let mut values_chunk = vec![T::zero(); options.chunk_size];
                    let mut x_vals_chunk = vec![T::zero(); options.chunk_size];

                    for (idx, k) in (j..j + options.chunk_size).enumerate() {
                        values_chunk[idx] = values[k];
                        x_vals_chunk[idx] = x[col_indices[k]];
                    }

                    let values_view = ArrayView1::from(&values_chunk);
                    let x_vals_view = ArrayView1::from(&x_vals_chunk);
                    let chunk_sum = T::simd_dot(&values_view, &x_vals_view);
                    sum = sum + chunk_sum;
                    j += options.chunk_size;
                }

                for k in j..end {
                    sum = sum + values[k] * x[col_indices[k]];
                }

                y[i] = sum;
            } else {
                let mut sum = T::zero();
                for k in start..end {
                    sum = sum + values[k] * x[col_indices[k]];
                }
                y[i] = sum;
            }
        }
    }

    Ok(y)
}

/// Element-wise binary operations supported by [`simd_sparse_elementwise`].
#[derive(Debug, Clone, Copy)]
pub enum ElementwiseOp {
    Add,
    Sub,
    Mul,
    Div,
}

/// Element-wise binary operation between two sparse arrays of the same shape,
/// returning a CSR result. Uses a fast path when both inputs downcast to
/// concrete `CsrArray`s and are large enough, otherwise falls back to the
/// generic `SparseArray` implementations.
#[allow(dead_code)]
pub fn simd_sparse_elementwise<T, S1, S2>(
    a: &S1,
    b: &S2,
    op: ElementwiseOp,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S1: SparseArray<T>,
    S2: SparseArray<T>,
{
    if a.shape() != b.shape() {
        return Err(SparseError::DimensionMismatch {
            expected: a.shape().0 * a.shape().1,
            found: b.shape().0 * b.shape().1,
        });
    }

    let opts = options.unwrap_or_default();

    let a_csr = a.to_csr()?;
    let b_csr = b.to_csr()?;

    let (_, _, a_values) = a_csr.find();
    let (_, _, b_values) = b_csr.find();

    if a_values.len() >= opts.min_simd_size && b_values.len() >= opts.min_simd_size {
        if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
            a_csr.as_any().downcast_ref::<CsrArray<T>>(),
            b_csr.as_any().downcast_ref::<CsrArray<T>>(),
        ) {
            // Fast path: both operands are concrete `CsrArray`s, so a single
            // generic kernel handles all four operations.
            let scalar_op: fn(T, T) -> T = match op {
                ElementwiseOp::Add => |x, y| x + y,
                ElementwiseOp::Sub => |x, y| x - y,
                ElementwiseOp::Mul => |x, y| x * y,
                ElementwiseOp::Div => |x, y| x / y,
            };
            simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, scalar_op)
        } else {
            // Fallback: use the trait-level implementations and downcast the result.
            let result_box = match op {
                ElementwiseOp::Add => a_csr.add(&*b_csr)?,
                ElementwiseOp::Sub => a_csr.sub(&*b_csr)?,
                ElementwiseOp::Mul => a_csr.mul(&*b_csr)?,
                ElementwiseOp::Div => a_csr.div(&*b_csr)?,
            };

            result_box
                .as_any()
                .downcast_ref::<CsrArray<T>>()
                .cloned()
                .ok_or_else(|| {
                    SparseError::ValueError("Failed to convert result to CsrArray".to_string())
                })
        }
    } else {
        let result_box = match op {
            ElementwiseOp::Add => a_csr.add(&*b_csr)?,
            ElementwiseOp::Sub => a_csr.sub(&*b_csr)?,
            ElementwiseOp::Mul => a_csr.mul(&*b_csr)?,
            ElementwiseOp::Div => a_csr.div(&*b_csr)?,
        };

        result_box
            .as_any()
            .downcast_ref::<CsrArray<T>>()
            .cloned()
            .ok_or_else(|| {
                SparseError::ValueError("Failed to convert result to CsrArray".to_string())
            })
    }
}

/// Combines two CSR matrices with a scalar binary operation applied over the
/// union of their sparsity patterns. Entries missing from one operand are
/// treated as zero, so `Div` can produce IEEE infinities. The work is chunked
/// for parallelism; the per-entry combination itself is scalar.
#[allow(dead_code)]
fn simd_sparse_binary_op<T, F>(
    a: &CsrArray<T>,
    b: &CsrArray<T>,
    options: &SimdOptions,
    op: F,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Copy,
{
    let (rows, cols) = a.shape();
    let mut result_rows = Vec::new();
    let mut result_cols = Vec::new();
    let mut result_values = Vec::new();

    let (a_row_indices, a_col_indices, a_values) = a.find();
    let (b_row_indices, b_col_indices, b_values) = b.find();

    // Index both operands by (row, col) for O(1) lookups.
    use std::collections::HashMap;
    let mut a_map = HashMap::new();
    let mut b_map = HashMap::new();

    for (i, (&row, &col)) in a_row_indices.iter().zip(a_col_indices.iter()).enumerate() {
        a_map.insert((row, col), a_values[i]);
    }

    for (i, (&row, &col)) in b_row_indices.iter().zip(b_col_indices.iter()).enumerate() {
        b_map.insert((row, col), b_values[i]);
    }

    // Union of both sparsity patterns, in sorted (row, col) order.
    let mut all_positions = std::collections::BTreeSet::new();
    for &pos in a_map.keys() {
        all_positions.insert(pos);
    }
    for &pos in b_map.keys() {
        all_positions.insert(pos);
    }

    let positions: Vec<_> = all_positions.into_iter().collect();

    if options.use_parallel && positions.len() >= options.parallel_threshold {
        let chunks: Vec<_> = positions.chunks(options.chunk_size).collect();
        let results: Vec<_> = parallel_map(&chunks, |chunk| {
            let mut local_rows = Vec::new();
            let mut local_cols = Vec::new();
            let mut local_values = Vec::new();

            for &(row, col) in *chunk {
                let a_val = a_map.get(&(row, col)).copied().unwrap_or(T::zero());
                let b_val = b_map.get(&(row, col)).copied().unwrap_or(T::zero());
                let result_val = op(a_val, b_val);

                if !result_val.is_zero() {
                    local_rows.push(row);
                    local_cols.push(col);
                    local_values.push(result_val);
                }
            }

            (local_rows, local_cols, local_values)
        });

        for (mut local_rows, mut local_cols, mut local_values) in results {
            result_rows.append(&mut local_rows);
            result_cols.append(&mut local_cols);
            result_values.append(&mut local_values);
        }
    } else {
        for (row, col) in positions {
            let a_val = a_map.get(&(row, col)).copied().unwrap_or(T::zero());
            let b_val = b_map.get(&(row, col)).copied().unwrap_or(T::zero());
            let result_val = op(a_val, b_val);

            if !result_val.is_zero() {
                result_rows.push(row);
                result_cols.push(col);
                result_values.push(result_val);
            }
        }
    }

    CsrArray::from_triplets(
        &result_rows,
        &result_cols,
        &result_values,
        (rows, cols),
        false,
    )
}

/// Transposes a sparse array into a CSR result by swapping row and column
/// indices, optionally assembling the triplets in parallel.
#[allow(dead_code)]
pub fn simd_sparse_transpose<T, S>(
    matrix: &S,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (rows, cols) = matrix.shape();
    let (row_indices, col_indices, values) = matrix.find();

    if opts.use_parallel && values.len() >= opts.parallel_threshold {
        let chunks: Vec<_> = (0..values.len())
            .collect::<Vec<_>>()
            .chunks(opts.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();

        let transposed_triplets: Vec<_> = parallel_map(&chunks, |chunk| {
            let mut local_rows = Vec::new();
            let mut local_cols = Vec::new();
            let mut local_values = Vec::new();

            for &idx in chunk {
                local_rows.push(col_indices[idx]);
                local_cols.push(row_indices[idx]);
                local_values.push(values[idx]);
            }

            (local_rows, local_cols, local_values)
        });

        let mut result_rows = Vec::new();
        let mut result_cols = Vec::new();
        let mut result_values = Vec::new();

        for (mut local_rows, mut local_cols, mut local_values) in transposed_triplets {
            result_rows.append(&mut local_rows);
            result_cols.append(&mut local_cols);
            result_values.append(&mut local_values);
        }

        CsrArray::from_triplets(
            &result_rows,
            &result_cols,
            &result_values,
            (cols, rows),
            false,
        )
    } else {
        CsrArray::from_triplets(
            col_indices.as_slice().expect("Array should be contiguous"),
            row_indices.as_slice().expect("Array should be contiguous"),
            values.as_slice().expect("Array should be contiguous"),
            (cols, rows),
            false,
        )
    }
}

/// Sparse matrix product `C = A * B`.
///
/// `A` is processed in CSR form and `B` in CSC form, so each output entry is
/// computed as a merge-join of a sorted row of `A` with a sorted column of
/// `B`. Every (row, column) pair is probed, so even very sparse results cost
/// `rows * cols` pair visits.
#[allow(dead_code)]
pub fn simd_sparse_matmul<T, S1, S2>(
    a: &S1,
    b: &S2,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S1: SparseArray<T>,
    S2: SparseArray<T>,
{
    if a.shape().1 != b.shape().0 {
        return Err(SparseError::DimensionMismatch {
            expected: a.shape().1,
            found: b.shape().0,
        });
    }

    let opts = options.unwrap_or_default();

    // CSR gives row-wise access to A; CSC gives column-wise access to B.
    let a_csr = a.to_csr()?;
    let b_csc = b.to_csc()?;

    let (a_rows, _a_cols) = a_csr.shape();
    let (_b_rows, b_cols) = b_csc.shape();

    let mut result_rows = Vec::new();
    let mut result_cols = Vec::new();
    let mut result_values = Vec::new();

    let a_indptr = if let Some(a_concrete) = a_csr.as_any().downcast_ref::<CsrArray<T>>() {
        a_concrete.get_indptr()
    } else {
        return Err(SparseError::ValueError(
            "Matrix A must be CSR format".to_string(),
        ));
    };
    let (_, a_col_indices, a_values) = a_csr.find();

    let b_indptr = if let Some(b_concrete) = b_csc.as_any().downcast_ref::<CscArray<T>>() {
        b_concrete.get_indptr()
    } else if let Some(b_concrete) = b_csc.as_any().downcast_ref::<CsrArray<T>>() {
        b_concrete.get_indptr()
    } else {
        return Err(SparseError::ValueError(
            "Matrix B must be CSC or CSR format".to_string(),
        ));
    };
    let (_, b_row_indices, b_values) = b_csc.find();

    if opts.use_parallel && a_rows >= opts.parallel_threshold {
        let chunks: Vec<_> = (0..a_rows)
            .collect::<Vec<_>>()
            .chunks(opts.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();
        let results: Vec<_> = parallel_map(&chunks, |row_chunk| {
            let mut local_rows = Vec::new();
            let mut local_cols = Vec::new();
            let mut local_values = Vec::new();

            for &i in row_chunk {
                let a_start = a_indptr[i];
                let a_end = a_indptr[i + 1];

                for j in 0..b_cols {
                    let b_start = b_indptr[j];
                    let b_end = b_indptr[j + 1];

                    // Merge-join the sorted index lists of row i of A and
                    // column j of B, accumulating products where they match.
                    let mut sum = T::zero();
                    let mut a_idx = a_start;
                    let mut b_idx = b_start;

                    while a_idx < a_end && b_idx < b_end {
                        let a_col = a_col_indices[a_idx];
                        let b_row = b_row_indices[b_idx];

                        match a_col.cmp(&b_row) {
                            std::cmp::Ordering::Equal => {
                                sum = sum + a_values[a_idx] * b_values[b_idx];
                                a_idx += 1;
                                b_idx += 1;
                            }
                            std::cmp::Ordering::Less => {
                                a_idx += 1;
                            }
                            std::cmp::Ordering::Greater => {
                                b_idx += 1;
                            }
                        }
                    }

                    if !sum.is_zero() {
                        local_rows.push(i);
                        local_cols.push(j);
                        local_values.push(sum);
                    }
                }
            }

            (local_rows, local_cols, local_values)
        });

        for (mut local_rows, mut local_cols, mut local_values) in results {
            result_rows.append(&mut local_rows);
            result_cols.append(&mut local_cols);
            result_values.append(&mut local_values);
        }
    } else {
        for i in 0..a_rows {
            let a_start = a_indptr[i];
            let a_end = a_indptr[i + 1];

            for j in 0..b_cols {
                let b_start = b_indptr[j];
                let b_end = b_indptr[j + 1];

                let mut sum = T::zero();
                let mut a_idx = a_start;
                let mut b_idx = b_start;

                while a_idx < a_end && b_idx < b_end {
                    let a_col = a_col_indices[a_idx];
                    let b_row = b_row_indices[b_idx];

                    match a_col.cmp(&b_row) {
                        std::cmp::Ordering::Equal => {
                            sum = sum + a_values[a_idx] * b_values[b_idx];
                            a_idx += 1;
                            b_idx += 1;
                        }
                        std::cmp::Ordering::Less => {
                            a_idx += 1;
                        }
                        std::cmp::Ordering::Greater => {
                            b_idx += 1;
                        }
                    }
                }

                if !sum.is_zero() {
                    result_rows.push(i);
                    result_cols.push(j);
                    result_values.push(sum);
                }
            }
        }
    }

    CsrArray::from_triplets(
        &result_rows,
        &result_cols,
        &result_values,
        (a_rows, b_cols),
        false,
    )
}

/// Computes a matrix norm of a sparse array.
///
/// Supported `norm_type` values: `"fro"`/`"frobenius"` (Frobenius norm),
/// `"1"` (maximum absolute column sum), and `"inf"`/`"infinity"` (maximum
/// absolute row sum).
#[allow(dead_code)]
pub fn simd_sparse_norm<T, S>(
    matrix: &S,
    norm_type: &str,
    options: Option<SimdOptions>,
) -> SparseResult<T>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (_, _, values) = matrix.find();

    match norm_type {
        "fro" | "frobenius" => {
            if opts.use_parallel && values.len() >= opts.parallel_threshold {
                let chunks: Vec<_> = values
                    .as_slice()
                    .expect("Array should be contiguous")
                    .chunks(opts.chunk_size)
                    .collect();
                let partial_sums: Vec<T> = parallel_map(&chunks, |chunk| {
                    let chunk_view = ArrayView1::from(chunk);
                    T::simd_dot(&chunk_view, &chunk_view)
                });
                Ok(partial_sums
                    .iter()
                    .copied()
                    .fold(T::zero(), |acc, x| acc + x)
                    .sqrt())
            } else {
                let values_view = values.view();
                let sum_squares = T::simd_dot(&values_view, &values_view);
                Ok(sum_squares.sqrt())
            }
        }
818 "1" => {
819 let (_rows, cols) = matrix.shape();
821 let (_row_indices, col_indices, values) = matrix.find();
822
823 let mut column_sums = vec![T::zero(); cols];
824
825 if opts.use_parallel && values.len() >= opts.parallel_threshold {
826 let chunks: Vec<_> = (0..values.len())
827 .collect::<Vec<_>>()
828 .chunks(opts.chunk_size)
829 .map(|chunk| chunk.to_vec())
830 .collect();
831 let partial_sums: Vec<Vec<T>> = parallel_map(&chunks, |chunk| {
832 let mut local_sums = vec![T::zero(); cols];
833 for &idx in chunk {
834 let col = col_indices[idx];
835 let val = values[idx].abs();
836 local_sums[col] = local_sums[col] + val;
837 }
838 local_sums
839 });
840
841 for partial_sum in partial_sums {
843 for j in 0..cols {
844 column_sums[j] = column_sums[j] + partial_sum[j];
845 }
846 }
847 } else {
848 for (i, &col) in col_indices.iter().enumerate() {
849 column_sums[col] = column_sums[col] + values[i].abs();
850 }
851 }
852
853 Ok(column_sums
854 .iter()
855 .copied()
856 .fold(T::zero(), |acc, x| if x > acc { x } else { acc }))
857 }
858 "inf" | "infinity" => {
859 let (rows, cols) = matrix.shape();
861 let (row_indices, col_indices, values) = matrix.find();
862
863 let mut row_sums = vec![T::zero(); rows];
864
865 if opts.use_parallel && values.len() >= opts.parallel_threshold {
866 let chunks: Vec<_> = (0..values.len())
867 .collect::<Vec<_>>()
868 .chunks(opts.chunk_size)
869 .map(|chunk| chunk.to_vec())
870 .collect();
871 let partial_sums: Vec<Vec<T>> = parallel_map(&chunks, |chunk| {
872 let mut local_sums = vec![T::zero(); rows];
873 for &idx in chunk {
874 let row = row_indices[idx];
875 let val = values[idx].abs();
876 local_sums[row] = local_sums[row] + val;
877 }
878 local_sums
879 });
880
881 for partial_sum in partial_sums {
883 for i in 0..rows {
884 row_sums[i] = row_sums[i] + partial_sum[i];
885 }
886 }
887 } else {
888 for (i, &row) in row_indices.iter().enumerate() {
889 row_sums[row] = row_sums[row] + values[i].abs();
890 }
891 }
892
893 Ok(row_sums
894 .iter()
895 .copied()
896 .fold(T::zero(), |acc, x| if x > acc { x } else { acc }))
897 }
        _ => Err(SparseError::ValueError(format!(
            "Unknown norm type: {norm_type}"
        ))),
    }
}

/// Scales every stored value of a sparse array by `scalar`, returning a CSR
/// result.
#[allow(dead_code)]
pub fn simd_sparse_scale<T, S>(
    matrix: &S,
    scalar: T,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (rows, cols) = matrix.shape();
    let (row_indices, col_indices, values) = matrix.find();

    let scaled_values = if opts.use_parallel && values.len() >= opts.parallel_threshold {
        // Scale the chunks in parallel; the multiply itself is plain scalar code.
        let chunks: Vec<_> = values
            .as_slice()
            .expect("Array should be contiguous")
            .chunks(opts.chunk_size)
            .collect();
        let scaled_chunks: Vec<Vec<T>> = parallel_map(&chunks, |chunk: &&[T]| {
            chunk.iter().map(|&val| val * scalar).collect()
        });

        scaled_chunks.into_iter().flatten().collect()
    } else {
        values.iter().map(|&val| val * scalar).collect::<Vec<T>>()
    };

    CsrArray::from_triplets(
        row_indices.as_slice().expect("Array should be contiguous"),
        col_indices.as_slice().expect("Array should be contiguous"),
        &scaled_values,
        (rows, cols),
        false,
    )
}

/// Computes the linear combination `sum_k coefficients[k] * matrices[k]` of
/// sparse arrays with identical shapes, returning a CSR result.
#[allow(dead_code)]
pub fn simd_sparse_linear_combination<T, S>(
    matrices: &[&S],
    coefficients: &[T],
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T> + Sync,
{
    if matrices.is_empty() {
        return Err(SparseError::ValueError("No matrices provided".to_string()));
    }

    if matrices.len() != coefficients.len() {
        return Err(SparseError::DimensionMismatch {
            expected: matrices.len(),
            found: coefficients.len(),
        });
    }

    let opts = options.unwrap_or_default();
    let (rows, cols) = matrices[0].shape();

    for matrix in matrices.iter() {
        if matrix.shape() != (rows, cols) {
            return Err(SparseError::DimensionMismatch {
                expected: rows * cols,
                found: matrix.shape().0 * matrix.shape().1,
            });
        }
    }

    // Accumulate coefficient-weighted entries keyed by (row, col).
    use std::collections::HashMap;
    let mut accumulator = HashMap::new();

    if opts.use_parallel && matrices.len() >= 4 {
        // Sum each matrix's triplets in parallel, then apply the coefficients
        // while merging the per-matrix maps.
        let results: Vec<HashMap<(usize, usize), T>> = parallel_map(matrices, |matrix| {
            let mut local_accumulator = HashMap::new();
            let (row_indices, col_indices, values) = matrix.find();

            for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
                let entry = local_accumulator.entry((i, j)).or_insert(T::zero());
                *entry = *entry + values[k];
            }

            local_accumulator
        });

        for (idx, local_acc) in results.into_iter().enumerate() {
            let coeff = coefficients[idx];
            for ((i, j), val) in local_acc {
                let entry = accumulator.entry((i, j)).or_insert(T::zero());
                *entry = *entry + coeff * val;
            }
        }
    } else {
        for (idx, matrix) in matrices.iter().enumerate() {
            let coeff = coefficients[idx];
            let (row_indices, col_indices, values) = matrix.find();

            for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
                let entry = accumulator.entry((i, j)).or_insert(T::zero());
                *entry = *entry + coeff * values[k];
            }
        }
    }

    let mut result_rows = Vec::new();
    let mut result_cols = Vec::new();
    let mut result_values = Vec::new();

    for ((i, j), val) in accumulator {
        if !val.is_zero() {
            result_rows.push(i);
            result_cols.push(j);
            result_values.push(val);
        }
    }

    CsrArray::from_triplets(
        &result_rows,
        &result_cols,
        &result_values,
        (rows, cols),
        false,
    )
}

/// Convenience wrapper around [`simd_sparse_matmul`] using default options.
#[allow(dead_code)]
pub fn simd_sparse_matmul_default<T, S1, S2>(a: &S1, b: &S2) -> SparseResult<CsrArray<T>>
where
    T: Float + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S1: SparseArray<T>,
    S2: SparseArray<T>,
{
    simd_sparse_matmul(a, b, None)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_csr_matvec() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = simd_csr_matvec(&matrix, &x.view(), SimdOptions::default()).unwrap();

        assert_eq!(y.len(), 3);
        assert_relative_eq!(y[0], 7.0);
        assert_relative_eq!(y[1], 6.0);
        assert_relative_eq!(y[2], 19.0);
    }

    #[test]
    fn test_simd_sparse_elementwise() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data1 = vec![1.0, 2.0, 3.0];
        let data2 = vec![4.0, 5.0, 6.0];

        let a = CsrArray::from_triplets(&rows, &cols, &data1, (3, 3), false).unwrap();
        let b = CsrArray::from_triplets(&rows, &cols, &data2, (3, 3), false).unwrap();

        let result = simd_sparse_elementwise(&a, &b, ElementwiseOp::Add, None).unwrap();

        assert_relative_eq!(result.get(0, 0), 5.0);
        assert_relative_eq!(result.get(1, 1), 7.0);
        assert_relative_eq!(result.get(2, 2), 9.0);
    }

    #[test]
    fn test_simd_sparse_matmul() {
        let rows = vec![0, 1];
        let cols = vec![0, 1];
        let data1 = vec![2.0, 3.0];
        let data2 = vec![4.0, 5.0];

        let a = CsrArray::from_triplets(&rows, &cols, &data1, (2, 2), false).unwrap();
        let b = CsrArray::from_triplets(&rows, &cols, &data2, (2, 2), false).unwrap();

        let result = simd_sparse_matmul_default(&a, &b).unwrap();

        assert_relative_eq!(result.get(0, 0), 8.0);
        assert_relative_eq!(result.get(1, 1), 15.0);
        assert_relative_eq!(result.get(0, 1), 0.0);
        assert_relative_eq!(result.get(1, 0), 0.0);
    }
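
    // Additional sketch tests for the remaining public helpers. They assume
    // the same `from_triplets`/`get`/`shape` behavior exercised by the tests
    // above; expected values are computed by hand from the definitions.

    #[test]
    fn test_simd_sparse_transpose() {
        let rows = vec![0, 0, 1];
        let cols = vec![0, 2, 1];
        let data = vec![1.0, 2.0, 3.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (2, 3), false).unwrap();

        let result = simd_sparse_transpose(&matrix, None).unwrap();

        // (i, j) entries of A become (j, i) entries of A^T.
        assert_eq!(result.shape(), (3, 2));
        assert_relative_eq!(result.get(0, 0), 1.0);
        assert_relative_eq!(result.get(2, 0), 2.0);
        assert_relative_eq!(result.get(1, 1), 3.0);
    }

    #[test]
    fn test_simd_sparse_norm() {
        let rows = vec![0, 1];
        let cols = vec![0, 1];
        let data = vec![3.0, 4.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap();

        // Frobenius: sqrt(3^2 + 4^2) = 5; the "1" and "inf" norms both reduce
        // to the largest absolute column/row sum, 4, for this diagonal matrix.
        assert_relative_eq!(simd_sparse_norm(&matrix, "fro", None).unwrap(), 5.0);
        assert_relative_eq!(simd_sparse_norm(&matrix, "1", None).unwrap(), 4.0);
        assert_relative_eq!(simd_sparse_norm(&matrix, "inf", None).unwrap(), 4.0);
        assert!(simd_sparse_norm(&matrix, "bogus", None).is_err());
    }

    #[test]
    fn test_simd_sparse_scale() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 2.0, 3.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let result = simd_sparse_scale(&matrix, 2.0, None).unwrap();

        assert_relative_eq!(result.get(0, 0), 2.0);
        assert_relative_eq!(result.get(1, 1), 4.0);
        assert_relative_eq!(result.get(2, 2), 6.0);
    }

    #[test]
    fn test_simd_sparse_linear_combination() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data1 = vec![1.0, 2.0, 3.0];
        let data2 = vec![4.0, 5.0, 6.0];

        let a = CsrArray::from_triplets(&rows, &cols, &data1, (3, 3), false).unwrap();
        let b = CsrArray::from_triplets(&rows, &cols, &data2, (3, 3), false).unwrap();

        // 1*A + 2*B over a shared diagonal pattern.
        let result = simd_sparse_linear_combination(&[&a, &b], &[1.0, 2.0], None).unwrap();

        assert_relative_eq!(result.get(0, 0), 9.0);
        assert_relative_eq!(result.get(1, 1), 12.0);
        assert_relative_eq!(result.get(2, 2), 15.0);
    }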
}