Skip to main content

scirs2_sparse/
csr.rs

1//! Compressed Sparse Row (CSR) matrix format
2//!
3//! This module provides the CSR matrix format implementation, which is
4//! efficient for row operations, matrix-vector multiplication, and more.
5
6use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{Float, SparseElement, Zero};
8use scirs2_core::GpuDataType;
9use std::cmp::PartialEq;
10
/// Compressed Sparse Row (CSR) matrix
///
/// A sparse matrix format that compresses rows, making it efficient for
/// row operations and matrix-vector multiplication.
///
/// Only explicitly stored entries are kept. The entries of row `i` occupy
/// positions `indptr[i]..indptr[i + 1]` of the parallel arrays `indices`
/// (column index) and `data` (value).
///
/// `indptr`, `indices` and `data` are public; code that mutates them directly
/// is responsible for keeping the three arrays consistent, because the
/// methods index into them without re-validating.
#[derive(Clone, Debug)]
pub struct CsrMatrix<T> {
    /// Number of rows
    rows: usize,
    /// Number of columns
    cols: usize,
    /// Row pointers (size rows+1); row `i` spans `indptr[i]..indptr[i + 1]`
    pub indptr: Vec<usize>,
    /// Column index of each stored entry, parallel to `data`
    pub indices: Vec<usize>,
    /// Value of each stored entry, parallel to `indices`
    pub data: Vec<T>,
}
28
29impl<T> CsrMatrix<T>
30where
31    T: Clone + Copy + Zero + PartialEq + SparseElement,
32{
33    /// Get the value at the specified position
34    pub fn get(&self, row: usize, col: usize) -> T {
35        // Check bounds
36        if row >= self.rows || col >= self.cols {
37            return T::sparse_zero();
38        }
39
40        // Find the element in the CSR format
41        for j in self.indptr[row]..self.indptr[row + 1] {
42            if self.indices[j] == col {
43                return self.data[j];
44            }
45        }
46
47        // Element not found, return zero
48        T::sparse_zero()
49    }
50
51    /// Get the triplets (row indices, column indices, data)
52    pub fn get_triplets(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
53        let mut rows = Vec::new();
54        let mut cols = Vec::new();
55        let mut values = Vec::new();
56
57        for i in 0..self.rows {
58            for j in self.indptr[i]..self.indptr[i + 1] {
59                rows.push(i);
60                cols.push(self.indices[j]);
61                values.push(self.data[j]);
62            }
63        }
64
65        (rows, cols, values)
66    }
67    /// Create a new CSR matrix from raw data
68    ///
69    /// # Arguments
70    ///
71    /// * `data` - Vector of non-zero values
72    /// * `rowindices` - Vector of row indices for each non-zero value
73    /// * `colindices` - Vector of column indices for each non-zero value
74    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
75    ///
76    /// # Returns
77    ///
78    /// * A new CSR matrix
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use scirs2_sparse::csr::CsrMatrix;
84    ///
85    /// // Create a 3x3 sparse matrix with 5 non-zero elements
86    /// let rows = vec![0, 0, 1, 2, 2];
87    /// let cols = vec![0, 2, 2, 0, 1];
88    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
89    /// let shape = (3, 3);
90    ///
91    /// let matrix = CsrMatrix::new(data.clone(), rows, cols, shape).expect("Operation failed");
92    /// ```
93    pub fn new(
94        data: Vec<T>,
95        rowindices: Vec<usize>,
96        colindices: Vec<usize>,
97        shape: (usize, usize),
98    ) -> SparseResult<Self> {
99        // Validate input data
100        if data.len() != rowindices.len() || data.len() != colindices.len() {
101            return Err(SparseError::DimensionMismatch {
102                expected: data.len(),
103                found: std::cmp::min(rowindices.len(), colindices.len()),
104            });
105        }
106
107        let (rows, cols) = shape;
108
109        // Check indices are within bounds
110        if rowindices.iter().any(|&i| i >= rows) {
111            return Err(SparseError::ValueError(
112                "Row index out of bounds".to_string(),
113            ));
114        }
115
116        if colindices.iter().any(|&i| i >= cols) {
117            return Err(SparseError::ValueError(
118                "Column index out of bounds".to_string(),
119            ));
120        }
121
122        // Convert triplet format to CSR
123        // First, sort by row, then by column
124        let mut triplets: Vec<(usize, usize, T)> = rowindices
125            .into_iter()
126            .zip(colindices)
127            .zip(data)
128            .map(|((r, c), v)| (r, c, v))
129            .collect();
130        triplets.sort_by_key(|&(r, c_, _)| (r, c_));
131
132        // Create indptr, indices, and data arrays
133        let nnz = triplets.len();
134        let mut indptr = vec![0; rows + 1];
135        let mut indices = Vec::with_capacity(nnz);
136        let mut data_out = Vec::with_capacity(nnz);
137
138        // Count elements per row to build indptr
139        for &(r_, _, _) in &triplets {
140            indptr[r_ + 1] += 1;
141        }
142
143        // Compute cumulative sum for indptr
144        for i in 1..=rows {
145            indptr[i] += indptr[i - 1];
146        }
147
148        // Fill indices and data
149        for (_r, c, v) in triplets {
150            indices.push(c);
151            data_out.push(v);
152        }
153
154        Ok(CsrMatrix {
155            rows,
156            cols,
157            indptr,
158            indices,
159            data: data_out,
160        })
161    }
162
163    /// Create a CSR matrix from triplet format (COO-like construction)
164    ///
165    /// This is a convenience constructor that builds a CSR matrix from
166    /// separate row indices, column indices, and values vectors.
167    ///
168    /// # Arguments
169    ///
170    /// * `nrows` - Number of rows in the matrix
171    /// * `ncols` - Number of columns in the matrix
172    /// * `row_indices` - Vector of row indices for each non-zero value
173    /// * `col_indices` - Vector of column indices for each non-zero value
174    /// * `values` - Vector of non-zero values
175    ///
176    /// # Returns
177    ///
178    /// * `Ok(CsrMatrix)` - A new CSR matrix
179    /// * `Err(SparseError)` - If input is invalid
180    ///
181    /// # Examples
182    ///
183    /// ```
184    /// use scirs2_sparse::csr::CsrMatrix;
185    ///
186    /// let row_indices = vec![0, 0, 1, 2, 2];
187    /// let col_indices = vec![0, 2, 2, 0, 1];
188    /// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
189    ///
190    /// let matrix = CsrMatrix::from_triplets(3, 3, row_indices, col_indices, values).expect("Operation failed");
191    /// assert_eq!(matrix.nnz(), 5);
192    /// ```
193    pub fn from_triplets(
194        nrows: usize,
195        ncols: usize,
196        row_indices: Vec<usize>,
197        col_indices: Vec<usize>,
198        values: Vec<T>,
199    ) -> SparseResult<Self> {
200        Self::new(values, row_indices, col_indices, (nrows, ncols))
201    }
202
203    /// Create a CSR matrix from triplet tuples
204    ///
205    /// This constructor accepts a slice of (row, col, value) tuples,
206    /// which is convenient for constructing matrices from coordinate lists.
207    ///
208    /// # Arguments
209    ///
210    /// * `nrows` - Number of rows in the matrix
211    /// * `ncols` - Number of columns in the matrix
212    /// * `triplets` - Slice of (row_index, col_index, value) tuples
213    ///
214    /// # Returns
215    ///
216    /// * `Ok(CsrMatrix)` - A new CSR matrix
217    /// * `Err(SparseError)` - If input is invalid
218    ///
219    /// # Examples
220    ///
221    /// ```
222    /// use scirs2_sparse::csr::CsrMatrix;
223    ///
224    /// let triplets = vec![
225    ///     (0, 0, 1.0),
226    ///     (0, 2, 2.0),
227    ///     (1, 2, 3.0),
228    ///     (2, 0, 4.0),
229    ///     (2, 1, 5.0),
230    /// ];
231    ///
232    /// let matrix = CsrMatrix::try_from_triplets(3, 3, &triplets).expect("Operation failed");
233    /// assert_eq!(matrix.nnz(), 5);
234    /// assert_eq!(matrix.get(0, 0), 1.0);
235    /// assert_eq!(matrix.get(2, 1), 5.0);
236    /// ```
237    pub fn try_from_triplets(
238        nrows: usize,
239        ncols: usize,
240        triplets: &[(usize, usize, T)],
241    ) -> SparseResult<Self> {
242        let mut row_indices = Vec::with_capacity(triplets.len());
243        let mut col_indices = Vec::with_capacity(triplets.len());
244        let mut values = Vec::with_capacity(triplets.len());
245
246        for &(r, c, v) in triplets {
247            row_indices.push(r);
248            col_indices.push(c);
249            values.push(v);
250        }
251
252        Self::from_triplets(nrows, ncols, row_indices, col_indices, values)
253    }
254
255    /// Create a new CSR matrix from raw CSR format
256    ///
257    /// # Arguments
258    ///
259    /// * `data` - Vector of non-zero values
260    /// * `indptr` - Vector of row pointers (size rows+1)
261    /// * `indices` - Vector of column indices
262    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
263    ///
264    /// # Returns
265    ///
266    /// * A new CSR matrix
267    pub fn from_raw_csr(
268        data: Vec<T>,
269        indptr: Vec<usize>,
270        indices: Vec<usize>,
271        shape: (usize, usize),
272    ) -> SparseResult<Self> {
273        let (rows, cols) = shape;
274
275        // Validate input data
276        if indptr.len() != rows + 1 {
277            return Err(SparseError::DimensionMismatch {
278                expected: rows + 1,
279                found: indptr.len(),
280            });
281        }
282
283        if data.len() != indices.len() {
284            return Err(SparseError::DimensionMismatch {
285                expected: data.len(),
286                found: indices.len(),
287            });
288        }
289
290        // Check if indptr is monotonically increasing
291        for i in 1..indptr.len() {
292            if indptr[i] < indptr[i - 1] {
293                return Err(SparseError::ValueError(
294                    "Row pointer array must be monotonically increasing".to_string(),
295                ));
296            }
297        }
298
299        // Check if the last indptr entry matches the data length
300        if indptr[rows] != data.len() {
301            return Err(SparseError::ValueError(
302                "Last row pointer entry must match data length".to_string(),
303            ));
304        }
305
306        // Check if indices are within bounds
307        if indices.iter().any(|&i| i >= cols) {
308            return Err(SparseError::ValueError(
309                "Column index out of bounds".to_string(),
310            ));
311        }
312
313        Ok(CsrMatrix {
314            rows,
315            cols,
316            indptr,
317            indices,
318            data,
319        })
320    }
321
322    /// Create a new empty CSR matrix
323    ///
324    /// # Arguments
325    ///
326    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
327    ///
328    /// # Returns
329    ///
330    /// * A new empty CSR matrix
331    pub fn empty(shape: (usize, usize)) -> Self {
332        let (rows, cols) = shape;
333        let indptr = vec![0; rows + 1];
334
335        CsrMatrix {
336            rows,
337            cols,
338            indptr,
339            indices: Vec::new(),
340            data: Vec::new(),
341        }
342    }
343
    /// Get the number of rows in the matrix
    ///
    /// This is the logical row count fixed at construction, not the number
    /// of rows that contain stored entries.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Get the number of columns in the matrix
    ///
    /// This is the logical column count fixed at construction.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Get the shape (dimensions) of the matrix as `(rows, cols)`
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Get the number of stored entries in the matrix
    ///
    /// This is the length of `data`, so explicitly stored zeros and duplicate
    /// positions are included in the count.
    pub fn nnz(&self) -> usize {
        self.data.len()
    }
363
364    /// Convert to dense matrix (as `Vec<Vec<T>>`)
365    pub fn to_dense(&self) -> Vec<Vec<T>>
366    where
367        T: Zero + Copy + SparseElement,
368    {
369        let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
370
371        for (row_idx, row) in result.iter_mut().enumerate() {
372            for j in self.indptr[row_idx]..self.indptr[row_idx + 1] {
373                let col_idx = self.indices[j];
374                row[col_idx] = self.data[j];
375            }
376        }
377
378        result
379    }
380
381    /// Transpose the matrix
382    pub fn transpose(&self) -> Self {
383        // Compute the number of non-zeros per column
384        let mut col_counts = vec![0; self.cols];
385        for &col in &self.indices {
386            col_counts[col] += 1;
387        }
388
389        // Compute column pointers (cumulative sum)
390        let mut col_ptrs = vec![0; self.cols + 1];
391        for i in 0..self.cols {
392            col_ptrs[i + 1] = col_ptrs[i] + col_counts[i];
393        }
394
395        // Fill the transposed matrix
396        let nnz = self.nnz();
397        let mut indices_t = vec![0; nnz];
398        let mut data_t = vec![T::sparse_zero(); nnz];
399        let mut col_counts = vec![0; self.cols];
400
401        for row in 0..self.rows {
402            for j in self.indptr[row]..self.indptr[row + 1] {
403                let col = self.indices[j];
404                let dest = col_ptrs[col] + col_counts[col];
405
406                indices_t[dest] = row;
407                data_t[dest] = self.data[j];
408                col_counts[col] += 1;
409            }
410        }
411
412        CsrMatrix {
413            rows: self.cols,
414            cols: self.rows,
415            indptr: col_ptrs,
416            indices: indices_t,
417            data: data_t,
418        }
419    }
420}
421
422impl<
423        T: Clone
424            + Copy
425            + std::ops::AddAssign
426            + std::ops::MulAssign
427            + std::cmp::PartialEq
428            + std::fmt::Debug
429            + scirs2_core::numeric::Zero
430            + std::ops::Add<Output = T>
431            + std::ops::Mul<Output = T>
432            + SparseElement,
433    > CsrMatrix<T>
434{
435    /// Check if matrix is symmetric
436    ///
437    /// # Returns
438    ///
439    /// * `true` if the matrix is symmetric, `false` otherwise
440    pub fn is_symmetric(&self) -> bool {
441        if self.rows != self.cols {
442            return false;
443        }
444
445        // Create a transposed matrix
446        let transposed = self.transpose();
447
448        // Compare the sparsity patterns and values
449        if self.nnz() != transposed.nnz() {
450            return false;
451        }
452
453        // Compare row by row
454        for row in 0..self.rows {
455            let self_start = self.indptr[row];
456            let self_end = self.indptr[row + 1];
457            let trans_start = transposed.indptr[row];
458            let trans_end = transposed.indptr[row + 1];
459
460            if self_end - self_start != trans_end - trans_start {
461                return false;
462            }
463
464            // Create sorted columns and values for this row
465            let mut self_entries: Vec<(usize, &T)> = (self_start..self_end)
466                .map(|j| (self.indices[j], &self.data[j]))
467                .collect();
468            self_entries.sort_by_key(|(col_, _)| *col_);
469
470            let mut trans_entries: Vec<(usize, &T)> = (trans_start..trans_end)
471                .map(|j| (transposed.indices[j], &transposed.data[j]))
472                .collect();
473            trans_entries.sort_by_key(|(col_, _)| *col_);
474
475            // Compare columns and values
476            for i in 0..self_entries.len() {
477                if self_entries[i].0 != trans_entries[i].0
478                    || self_entries[i].1 != trans_entries[i].1
479                {
480                    return false;
481                }
482            }
483        }
484
485        true
486    }
487
488    /// Matrix-matrix multiplication
489    ///
490    /// # Arguments
491    ///
492    /// * `other` - Matrix to multiply with
493    ///
494    /// # Returns
495    ///
496    /// * Result containing the product matrix
497    pub fn matmul(&self, other: &CsrMatrix<T>) -> SparseResult<CsrMatrix<T>> {
498        if self.cols != other.rows {
499            return Err(SparseError::DimensionMismatch {
500                expected: self.cols,
501                found: other.rows,
502            });
503        }
504
505        // For simplicity, we'll implement this using dense operations
506        // In a real implementation, you'd use a more efficient sparse algorithm
507        let a_dense = self.to_dense();
508        let b_dense = other.to_dense();
509
510        let m = self.rows;
511        let n = other.cols;
512        let k = self.cols;
513
514        let mut c_dense = vec![vec![T::sparse_zero(); n]; m];
515
516        for (i, c_row) in c_dense.iter_mut().enumerate().take(m) {
517            for (j, val) in c_row.iter_mut().enumerate().take(n) {
518                for (l, &a_val) in a_dense[i].iter().enumerate().take(k) {
519                    let prod = a_val * b_dense[l][j];
520                    *val += prod;
521                }
522            }
523        }
524
525        // Convert back to CSR format
526        let mut rowindices = Vec::new();
527        let mut colindices = Vec::new();
528        let mut values = Vec::new();
529
530        for (i, row) in c_dense.iter().enumerate() {
531            for (j, val) in row.iter().enumerate() {
532                if *val != T::sparse_zero() {
533                    rowindices.push(i);
534                    colindices.push(j);
535                    values.push(*val);
536                }
537            }
538        }
539
540        CsrMatrix::new(values, rowindices, colindices, (m, n))
541    }
542
543    /// Get row range for iterating over elements in a row
544    ///
545    /// # Arguments
546    ///
547    /// * `row` - Row index
548    ///
549    /// # Returns
550    ///
551    /// * Range of indices in the data and indices arrays for this row
552    pub fn row_range(&self, row: usize) -> std::ops::Range<usize> {
553        assert!(row < self.rows, "Row index out of bounds");
554        self.indptr[row]..self.indptr[row + 1]
555    }
556
557    /// Get column indices array
558    pub fn colindices(&self) -> &[usize] {
559        &self.indices
560    }
561
562    /// Extract a contiguous submatrix from the given row and column ranges.
563    ///
564    /// # Arguments
565    ///
566    /// * `row_start` – First row to include (inclusive).
567    /// * `row_end`   – Last row to include (exclusive); clamped to `rows()`.
568    /// * `col_start` – First column to include (inclusive).
569    /// * `col_end`   – Last column to include (exclusive); clamped to `cols()`.
570    ///
571    /// # Returns
572    ///
573    /// A new `CsrMatrix` containing only the specified sub-block.
574    ///
575    /// # Errors
576    ///
577    /// Returns an error if the range is empty (`row_start >= row_end` or
578    /// `col_start >= col_end`) or if the start indices exceed the matrix
579    /// dimensions.
580    ///
581    /// # Examples
582    ///
583    /// ```
584    /// use scirs2_sparse::csr::CsrMatrix;
585    ///
586    /// // 4×4 identity
587    /// let rows = vec![0usize, 1, 2, 3];
588    /// let cols = vec![0usize, 1, 2, 3];
589    /// let data = vec![1.0f64; 4];
590    /// let m = CsrMatrix::new(data, rows, cols, (4, 4)).unwrap();
591    ///
592    /// // Top-left 2×2 block
593    /// let sub = m.submatrix(0, 2, 0, 2).unwrap();
594    /// assert_eq!(sub.rows(), 2);
595    /// assert_eq!(sub.cols(), 2);
596    /// assert_eq!(sub.get(0, 0), 1.0);
597    /// assert_eq!(sub.get(1, 1), 1.0);
598    /// assert_eq!(sub.get(0, 1), 0.0);
599    /// ```
600    pub fn submatrix(
601        &self,
602        row_start: usize,
603        row_end: usize,
604        col_start: usize,
605        col_end: usize,
606    ) -> SparseResult<CsrMatrix<T>> {
607        let row_end = row_end.min(self.rows);
608        let col_end = col_end.min(self.cols);
609        if row_start >= row_end {
610            return Err(SparseError::ValueError(format!(
611                "submatrix: row_start ({}) >= row_end ({})",
612                row_start, row_end
613            )));
614        }
615        if col_start >= col_end {
616            return Err(SparseError::ValueError(format!(
617                "submatrix: col_start ({}) >= col_end ({})",
618                col_start, col_end
619            )));
620        }
621
622        let new_rows = row_end - row_start;
623        let new_cols = col_end - col_start;
624        let mut rows_out = Vec::new();
625        let mut cols_out = Vec::new();
626        let mut data_out = Vec::new();
627
628        for i in row_start..row_end {
629            let range = self.indptr[i]..self.indptr[i + 1];
630            for pos in range {
631                let j = self.indices[pos];
632                if j >= col_start && j < col_end {
633                    rows_out.push(i - row_start);
634                    cols_out.push(j - col_start);
635                    data_out.push(self.data[pos]);
636                }
637            }
638        }
639
640        CsrMatrix::new(data_out, rows_out, cols_out, (new_rows, new_cols))
641    }
642
643    /// Element-wise (Hadamard) product: `C[i,j] = A[i,j] * B[i,j]`.
644    ///
645    /// Only entries that are non-zero in *both* matrices contribute to the
646    /// result — entries missing from either operand are treated as zero.
647    ///
648    /// # Arguments
649    ///
650    /// * `other` – The right-hand matrix; must have the same shape as `self`.
651    ///
652    /// # Errors
653    ///
654    /// Returns an error when the two matrices have different shapes.
655    ///
656    /// # Examples
657    ///
658    /// ```
659    /// use scirs2_sparse::csr::CsrMatrix;
660    ///
661    /// // A = diag(1, 2, 3)  B = diag(4, 5, 6)
662    /// // C = diag(4, 10, 18)
663    /// let a = CsrMatrix::new(
664    ///     vec![1.0f64, 2.0, 3.0],
665    ///     vec![0usize, 1, 2],
666    ///     vec![0usize, 1, 2],
667    ///     (3, 3),
668    /// ).unwrap();
669    /// let b = CsrMatrix::new(
670    ///     vec![4.0f64, 5.0, 6.0],
671    ///     vec![0usize, 1, 2],
672    ///     vec![0usize, 1, 2],
673    ///     (3, 3),
674    /// ).unwrap();
675    /// let c = a.elementwise_mul(&b).unwrap();
676    /// assert_eq!(c.get(0, 0), 4.0);
677    /// assert_eq!(c.get(1, 1), 10.0);
678    /// assert_eq!(c.get(2, 2), 18.0);
679    /// ```
680    pub fn elementwise_mul(&self, other: &CsrMatrix<T>) -> SparseResult<CsrMatrix<T>>
681    where
682        T: std::ops::Mul<Output = T>,
683    {
684        if self.rows != other.rows || self.cols != other.cols {
685            return Err(SparseError::DimensionMismatch {
686                expected: self.rows * self.cols,
687                found: other.rows * other.cols,
688            });
689        }
690
691        let n = self.rows;
692        let nc = self.cols;
693        let mut rows_out = Vec::new();
694        let mut cols_out = Vec::new();
695        let mut data_out = Vec::new();
696
697        // For each row, intersect the non-zero column sets of A and B.
698        // Build a temporary lookup for row i of B using a flat array (safe for
699        // moderate column counts).  For very wide matrices a HashMap would be
700        // preferred, but sparse matrices in CSR format typically have few nnz
701        // per row, so the linear scan below is fast enough.
702        let mut b_row_buf: Vec<(usize, T)> = Vec::new();
703
704        for i in 0..n {
705            // Collect B's non-zeros for row i.
706            b_row_buf.clear();
707            let b_range = other.indptr[i]..other.indptr[i + 1];
708            for pos in b_range {
709                b_row_buf.push((other.indices[pos], other.data[pos]));
710            }
711
712            // Intersect with A's non-zeros for row i.
713            let a_range = self.indptr[i]..self.indptr[i + 1];
714            for pos in a_range {
715                let j = self.indices[pos];
716                let a_val = self.data[pos];
717                // Look up j in b_row_buf (linear scan; typical rows are short).
718                if let Some(&(_, b_val)) = b_row_buf.iter().find(|&&(bj, _)| bj == j) {
719                    let product = a_val * b_val;
720                    if product != T::sparse_zero() {
721                        rows_out.push(i);
722                        cols_out.push(j);
723                        data_out.push(product);
724                    }
725                }
726            }
727        }
728
729        CsrMatrix::new(data_out, rows_out, cols_out, (n, nc))
730    }
731}
732
733impl CsrMatrix<f64> {
734    /// Matrix-vector multiplication
735    ///
736    /// # Arguments
737    ///
738    /// * `vec` - Vector to multiply with
739    ///
740    /// # Returns
741    ///
742    /// * Result of matrix-vector multiplication
743    pub fn dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
744        if vec.len() != self.cols {
745            return Err(SparseError::DimensionMismatch {
746                expected: self.cols,
747                found: vec.len(),
748            });
749        }
750
751        let mut result = vec![0.0; self.rows];
752
753        for (row_idx, result_val) in result.iter_mut().enumerate() {
754            for j in self.indptr[row_idx]..self.indptr[row_idx + 1] {
755                let col_idx = self.indices[j];
756                *result_val += self.data[j] * vec[col_idx];
757            }
758        }
759
760        Ok(result)
761    }
762
763    /// GPU-accelerated matrix-vector multiplication
764    ///
765    /// This method automatically uses GPU acceleration when beneficial,
766    /// falling back to optimized CPU implementation when appropriate.
767    ///
768    /// # Arguments
769    ///
770    /// * `vec` - Vector to multiply with
771    ///
772    /// # Returns
773    ///
774    /// * Result of matrix-vector multiplication
775    ///
776    /// # Examples
777    ///
778    /// ```
779    /// use scirs2_sparse::csr::CsrMatrix;
780    ///
781    /// let rows = vec![0, 0, 1, 2, 2];
782    /// let cols = vec![0, 2, 2, 0, 1];
783    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
784    /// let shape = (3, 3);
785    ///
786    /// let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");
787    /// let vec = vec![1.0, 2.0, 3.0];
788    /// let result = matrix.gpu_dot(&vec).expect("Operation failed");
789    /// ```
790    #[allow(dead_code)]
791    pub fn gpu_dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
792        // Use the GpuSpMV implementation
793        let gpu_spmv = crate::gpu_spmv_implementation::GpuSpMV::new()?;
794        gpu_spmv.spmv(
795            self.rows,
796            self.cols,
797            &self.indptr,
798            &self.indices,
799            &self.data,
800            vec,
801        )
802    }
803
804    /// GPU-accelerated matrix-vector multiplication with backend selection
805    ///
806    /// # Arguments
807    ///
808    /// * `vec` - Vector to multiply with
809    /// * `backend` - Preferred GPU backend
810    ///
811    /// # Returns
812    ///
813    /// * Result of matrix-vector multiplication
814    #[allow(dead_code)]
815    pub fn gpu_dot_with_backend(
816        &self,
817        vec: &[f64],
818        backend: scirs2_core::gpu::GpuBackend,
819    ) -> SparseResult<Vec<f64>> {
820        // Use the GpuSpMV implementation with specified backend
821        let gpu_spmv = crate::gpu_spmv_implementation::GpuSpMV::with_backend(backend)?;
822        gpu_spmv.spmv(
823            self.rows,
824            self.cols,
825            &self.indptr,
826            &self.indices,
827            &self.data,
828            vec,
829        )
830    }
831}
832
833impl<T> CsrMatrix<T>
834where
835    T: scirs2_core::numeric::Float
836        + std::fmt::Debug
837        + Copy
838        + Default
839        + GpuDataType
840        + Send
841        + Sync
842        + SparseElement
843        + std::ops::AddAssign
844        + std::ops::Mul<Output = T>
845        + 'static,
846{
847    /// GPU-accelerated matrix-vector multiplication for generic floating-point types
848    ///
849    /// # Arguments
850    ///
851    /// * `vec` - Vector to multiply with
852    ///
853    /// # Returns
854    ///
855    /// * Result of matrix-vector multiplication
856    #[allow(dead_code)]
857    pub fn gpu_dot_generic(&self, vec: &[T]) -> SparseResult<Vec<T>>
858where {
859        // GPU operations fall back to CPU for stability
860        if vec.len() != self.cols {
861            return Err(SparseError::DimensionMismatch {
862                expected: self.cols,
863                found: vec.len(),
864            });
865        }
866
867        let mut result = vec![T::sparse_zero(); self.rows];
868
869        for (row_idx, result_val) in result.iter_mut().enumerate() {
870            let start = self.indptr[row_idx];
871            let end = self.indptr[row_idx + 1];
872
873            for idx in start..end {
874                let col = self.indices[idx];
875                *result_val += self.data[idx] * vec[col];
876            }
877        }
878
879        Ok(result)
880    }
881
882    /// Check if this matrix should benefit from GPU acceleration
883    ///
884    /// # Returns
885    ///
886    /// * `true` if GPU acceleration is likely to provide benefits
887    pub fn should_use_gpu(&self) -> bool {
888        // Use GPU for matrices with significant computation (> 10k non-zeros)
889        // and reasonable sparsity (< 50% dense)
890        let nnz_threshold = 10000;
891        let density = self.nnz() as f64 / (self.rows * self.cols) as f64;
892
893        self.nnz() > nnz_threshold && density < 0.5
894    }
895
896    /// Get GPU backend information
897    ///
898    /// # Returns
899    ///
900    /// * Information about available GPU backends
901    #[allow(dead_code)]
902    pub fn gpu_backend_info() -> SparseResult<(crate::gpu_ops::GpuBackend, String)> {
903        // GPU operations fall back to CPU for stability
904        Ok((crate::gpu_ops::GpuBackend::Cpu, "CPU Fallback".to_string()))
905    }
906}
907
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds the 3x3 fixture with 5 non-zero elements shared by the tests below:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 0 3]
    /// [4 5 0]
    /// ```
    fn sample_matrix() -> CsrMatrix<f64> {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        CsrMatrix::new(data, rows, cols, shape).expect("Operation failed")
    }

    /// Construction from triplets reports the requested shape and entry count.
    #[test]
    fn test_csr_create() {
        let matrix = sample_matrix();

        assert_eq!(matrix.shape(), (3, 3));
        assert_eq!(matrix.nnz(), 5);
    }

    /// The dense conversion places every stored value at its (row, col) slot
    /// and fills the remaining slots with zero.
    #[test]
    fn test_csr_to_dense() {
        let dense = sample_matrix().to_dense();

        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    /// Matrix-vector product against `[1, 2, 3]`.
    #[test]
    fn test_csr_dot() {
        let matrix = sample_matrix();

        let vec = vec![1.0, 2.0, 3.0];
        let result = matrix.dot(&vec).expect("Operation failed");

        // Row by row:
        // 1*1 + 0*2 + 2*3 = 7
        // 0*1 + 0*2 + 3*3 = 9
        // 4*1 + 5*2 + 0*3 = 14
        let expected = [7.0, 9.0, 14.0];

        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }
    }

    /// Transposing keeps the shape (square matrix) and the entry count, and
    /// mirrors the values across the diagonal.
    #[test]
    fn test_csr_transpose() {
        let transposed = sample_matrix().transpose();

        assert_eq!(transposed.shape(), (3, 3));
        assert_eq!(transposed.nnz(), 5);

        let dense = transposed.to_dense();
        let expected = vec![
            vec![1.0, 0.0, 4.0],
            vec![0.0, 0.0, 5.0],
            vec![2.0, 3.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    /// `gpu_dot` must agree with the CPU result whenever it succeeds.
    #[test]
    fn test_gpu_dot() {
        let matrix = sample_matrix();
        let vec = vec![1.0, 2.0, 3.0];

        // Test GPU-accelerated SpMV (skip gracefully if GPU is unavailable)
        match matrix.gpu_dot(&vec) {
            Ok(result) => {
                let expected = [7.0, 9.0, 14.0];
                assert_eq!(result.len(), expected.len());
                for (a, b) in result.iter().zip(expected.iter()) {
                    assert_relative_eq!(a, b, epsilon = 1e-10);
                }
            }
            Err(crate::error::SparseError::ComputationError(_))
            | Err(crate::error::SparseError::OperationNotSupported(_)) => {
                // Acceptable when no GPU is available in CI/local machines
            }
            Err(e) => panic!("Unexpected error in GPU SpMV: {:?}", e),
        }
    }

    /// `should_use_gpu` is expected to reject a tiny matrix and accept a
    /// large diagonal one.
    #[test]
    fn test_should_use_gpu() {
        // Small matrix - should not use GPU
        let small_matrix = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 2))
            .expect("Operation failed");
        assert!(
            !small_matrix.should_use_gpu(),
            "Small matrix should not use GPU"
        );

        // Large sparse matrix - should use GPU.
        // One entry per row on the diagonal of an N x N matrix.
        const N: usize = 15000;
        let large_data = vec![1.0; N];
        let large_rows: Vec<usize> = (0..N).collect();
        let large_cols: Vec<usize> = (0..N).collect();
        let large_matrix = CsrMatrix::new(large_data, large_rows, large_cols, (N, N))
            .expect("Operation failed");
        assert!(
            large_matrix.should_use_gpu(),
            "Large sparse matrix should use GPU"
        );
    }

    /// Querying the backend must succeed and yield a non-empty name.
    #[test]
    fn test_gpu_backend_info() {
        // `expect` both enforces success and surfaces the error value on failure.
        let (backend, name) =
            CsrMatrix::<f64>::gpu_backend_info().expect("Should be able to get GPU backend info");

        assert!(!name.is_empty(), "Backend name should not be empty");

        // Backend should be one of the supported types. The match has no
        // wildcard arm on purpose: adding a variant to `GpuBackend` will make
        // this test fail to compile until it is listed here.
        match backend {
            crate::gpu_ops::GpuBackend::Cuda
            | crate::gpu_ops::GpuBackend::OpenCL
            | crate::gpu_ops::GpuBackend::Metal
            | crate::gpu_ops::GpuBackend::Cpu
            | crate::gpu_ops::GpuBackend::Rocm
            | crate::gpu_ops::GpuBackend::Wgpu => {}
            #[cfg(not(feature = "gpu"))]
            crate::gpu_ops::GpuBackend::Vulkan => {}
        }
    }

    /// Same product as `test_gpu_dot`, but through the generic entry point
    /// with `f32` elements and a correspondingly looser tolerance.
    #[test]
    fn test_gpu_dot_generic_f32() {
        // Test with f32 type
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).expect("Operation failed");
        let vec = vec![1.0f32, 2.0, 3.0];

        match matrix.gpu_dot_generic(&vec) {
            Ok(result) => {
                let expected = [7.0f32, 9.0, 14.0];
                assert_eq!(result.len(), expected.len());
                for (a, b) in result.iter().zip(expected.iter()) {
                    assert_relative_eq!(a, b, epsilon = 1e-6);
                }
            }
            // Tolerated so the test passes on machines without a usable GPU.
            Err(crate::error::SparseError::ComputationError(_))
            | Err(crate::error::SparseError::OperationNotSupported(_)) => {}
            Err(e) => panic!("Unexpected error in generic GPU SpMV: {:?}", e),
        }
    }
}