scirs2_sparse/csr.rs

//! Compressed Sparse Row (CSR) matrix format
//!
//! This module provides the CSR matrix format implementation, which is
//! efficient for row operations and matrix-vector multiplication.
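//!
//! # Storage layout
//!
//! A matrix is stored as three parallel arrays: `indptr` (row pointers, length
//! `rows + 1`), `indices` (the column index of each stored value), and `data`
//! (the stored values). The entries of row `i` occupy the half-open range
//! `indptr[i]..indptr[i + 1]` in `indices` and `data`. A minimal sketch using
//! this module's own constructor (the same 3x3 matrix that appears in the
//! examples and tests below):
//!
//! ```
//! use scirs2_sparse::csr::CsrMatrix;
//!
//! // Dense matrix:
//! // [1 0 2]
//! // [0 0 3]
//! // [4 5 0]
//! let rows = vec![0, 0, 1, 2, 2];
//! let cols = vec![0, 2, 2, 0, 1];
//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
//! let matrix = CsrMatrix::new(data, rows, cols, (3, 3)).unwrap();
//!
//! // Row i is described by indptr[i]..indptr[i + 1].
//! assert_eq!(matrix.indptr, vec![0, 2, 3, 5]);
//! assert_eq!(matrix.indices, vec![0, 2, 2, 0, 1]);
//! ```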

use crate::error::{SparseError, SparseResult};
use scirs2_core::numeric::{Float, Zero};
use scirs2_core::GpuDataType;
use std::cmp::PartialEq;

/// Compressed Sparse Row (CSR) matrix
///
/// A sparse matrix format that compresses rows, making it efficient for
/// row operations and matrix-vector multiplication.
#[derive(Clone)]
pub struct CsrMatrix<T> {
    /// Number of rows
    rows: usize,
    /// Number of columns
    cols: usize,
    /// Row pointers (size rows+1)
    pub indptr: Vec<usize>,
    /// Column indices
    pub indices: Vec<usize>,
    /// Data values
    pub data: Vec<T>,
}

impl<T> CsrMatrix<T>
where
    T: Clone + Copy + Zero + PartialEq,
{
    /// Get the value at the specified position
    ///
    /// Returns `T::zero()` if the position is out of bounds or if no entry is
    /// stored at `(row, col)`.
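    ///
    /// # Examples
    ///
    /// A minimal example (illustrative only), reusing the 3x3 matrix from the
    /// module-level docs:
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let rows = vec![0, 0, 1, 2, 2];
    /// let cols = vec![0, 2, 2, 0, 1];
    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    /// let matrix = CsrMatrix::new(data, rows, cols, (3, 3)).unwrap();
    ///
    /// assert_eq!(matrix.get(0, 2), 2.0); // stored entry
    /// assert_eq!(matrix.get(1, 0), 0.0); // implicit zero
    /// ```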
    pub fn get(&self, row: usize, col: usize) -> T {
        // Check bounds
        if row >= self.rows || col >= self.cols {
            return T::zero();
        }

        // Find the element in the CSR format
        for j in self.indptr[row]..self.indptr[row + 1] {
            if self.indices[j] == col {
                return self.data[j];
            }
        }

        // Element not found, return zero
        T::zero()
    }

    /// Get the triplets (row indices, column indices, data)
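    ///
    /// # Examples
    ///
    /// A small sketch; the triplets are returned in storage order (row by row):
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let matrix = CsrMatrix::new(
    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0],
    ///     vec![0, 0, 1, 2, 2],
    ///     vec![0, 2, 2, 0, 1],
    ///     (3, 3),
    /// )
    /// .unwrap();
    ///
    /// let (r, c, v) = matrix.get_triplets();
    /// assert_eq!(r, vec![0, 0, 1, 2, 2]);
    /// assert_eq!(c, vec![0, 2, 2, 0, 1]);
    /// assert_eq!(v, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    /// ```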
    pub fn get_triplets(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut values = Vec::new();

        for i in 0..self.rows {
            for j in self.indptr[i]..self.indptr[i + 1] {
                rows.push(i);
                cols.push(self.indices[j]);
                values.push(self.data[j]);
            }
        }

        (rows, cols, values)
    }

    /// Create a new CSR matrix from raw data
    ///
    /// # Arguments
    ///
    /// * `data` - Vector of non-zero values
    /// * `rowindices` - Vector of row indices for each non-zero value
    /// * `colindices` - Vector of column indices for each non-zero value
    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
    ///
    /// # Returns
    ///
    /// * A new CSR matrix
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// // Create a 3x3 sparse matrix with 5 non-zero elements
    /// let rows = vec![0, 0, 1, 2, 2];
    /// let cols = vec![0, 2, 2, 0, 1];
    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    /// let shape = (3, 3);
    ///
    /// let matrix = CsrMatrix::new(data.clone(), rows, cols, shape).unwrap();
    /// ```
    pub fn new(
        data: Vec<T>,
        rowindices: Vec<usize>,
        colindices: Vec<usize>,
        shape: (usize, usize),
    ) -> SparseResult<Self> {
        // Validate input data
        if data.len() != rowindices.len() || data.len() != colindices.len() {
            return Err(SparseError::DimensionMismatch {
                expected: data.len(),
                found: std::cmp::min(rowindices.len(), colindices.len()),
            });
        }

        let (rows, cols) = shape;

        // Check indices are within bounds
        if rowindices.iter().any(|&i| i >= rows) {
            return Err(SparseError::ValueError(
                "Row index out of bounds".to_string(),
            ));
        }

        if colindices.iter().any(|&i| i >= cols) {
            return Err(SparseError::ValueError(
                "Column index out of bounds".to_string(),
            ));
        }

        // Convert triplet format to CSR
        // First, sort by row, then by column
        let mut triplets: Vec<(usize, usize, T)> = rowindices
            .into_iter()
            .zip(colindices)
            .zip(data)
            .map(|((r, c), v)| (r, c, v))
            .collect();
        triplets.sort_by_key(|&(r, c_, _)| (r, c_));

        // Create indptr, indices, and data arrays
        let nnz = triplets.len();
        let mut indptr = vec![0; rows + 1];
        let mut indices = Vec::with_capacity(nnz);
        let mut data_out = Vec::with_capacity(nnz);

        // Count elements per row to build indptr
        for &(r_, _, _) in &triplets {
            indptr[r_ + 1] += 1;
        }

        // Compute cumulative sum for indptr
        for i in 1..=rows {
            indptr[i] += indptr[i - 1];
        }

        // Fill indices and data
        for (_r, c, v) in triplets {
            indices.push(c);
            data_out.push(v);
        }

        Ok(CsrMatrix {
            rows,
            cols,
            indptr,
            indices,
            data: data_out,
        })
    }

    /// Create a new CSR matrix from raw CSR format
    ///
    /// # Arguments
    ///
    /// * `data` - Vector of non-zero values
    /// * `indptr` - Vector of row pointers (size rows+1)
    /// * `indices` - Vector of column indices
    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
    ///
    /// # Returns
    ///
    /// * A new CSR matrix
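    ///
    /// # Examples
    ///
    /// A minimal sketch building the same 3x3 matrix as the module docs, but
    /// directly from CSR arrays:
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    /// let indptr = vec![0, 2, 3, 5];
    /// let indices = vec![0, 2, 2, 0, 1];
    ///
    /// let matrix = CsrMatrix::from_raw_csr(data, indptr, indices, (3, 3)).unwrap();
    /// assert_eq!(matrix.nnz(), 5);
    /// assert_eq!(matrix.get(1, 2), 3.0);
    /// ```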
    pub fn from_raw_csr(
        data: Vec<T>,
        indptr: Vec<usize>,
        indices: Vec<usize>,
        shape: (usize, usize),
    ) -> SparseResult<Self> {
        let (rows, cols) = shape;

        // Validate input data
        if indptr.len() != rows + 1 {
            return Err(SparseError::DimensionMismatch {
                expected: rows + 1,
                found: indptr.len(),
            });
        }

        if data.len() != indices.len() {
            return Err(SparseError::DimensionMismatch {
                expected: data.len(),
                found: indices.len(),
            });
        }

        // Check if indptr is monotonically increasing
        for i in 1..indptr.len() {
            if indptr[i] < indptr[i - 1] {
                return Err(SparseError::ValueError(
                    "Row pointer array must be monotonically increasing".to_string(),
                ));
            }
        }

        // Check if the last indptr entry matches the data length
        if indptr[rows] != data.len() {
            return Err(SparseError::ValueError(
                "Last row pointer entry must match data length".to_string(),
            ));
        }

        // Check if indices are within bounds
        if indices.iter().any(|&i| i >= cols) {
            return Err(SparseError::ValueError(
                "Column index out of bounds".to_string(),
            ));
        }

        Ok(CsrMatrix {
            rows,
            cols,
            indptr,
            indices,
            data,
        })
    }

    /// Create a new empty CSR matrix
    ///
    /// # Arguments
    ///
    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
    ///
    /// # Returns
    ///
    /// * A new empty CSR matrix
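    ///
    /// # Examples
    ///
    /// A short illustration; the element type must be spelled out because
    /// there is no data to infer it from:
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let matrix = CsrMatrix::<f64>::empty((3, 4));
    /// assert_eq!(matrix.shape(), (3, 4));
    /// assert_eq!(matrix.nnz(), 0);
    /// ```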
    pub fn empty(shape: (usize, usize)) -> Self {
        let (rows, cols) = shape;
        let indptr = vec![0; rows + 1];

        CsrMatrix {
            rows,
            cols,
            indptr,
            indices: Vec::new(),
            data: Vec::new(),
        }
    }

    /// Get the number of rows in the matrix
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Get the number of columns in the matrix
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Get the shape (dimensions) of the matrix
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Get the number of non-zero elements in the matrix
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// Convert to dense matrix (as Vec<Vec<T>>)
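    ///
    /// # Examples
    ///
    /// A minimal example (same matrix as the module docs):
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let matrix = CsrMatrix::new(
    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0],
    ///     vec![0, 0, 1, 2, 2],
    ///     vec![0, 2, 2, 0, 1],
    ///     (3, 3),
    /// )
    /// .unwrap();
    ///
    /// assert_eq!(
    ///     matrix.to_dense(),
    ///     vec![
    ///         vec![1.0, 0.0, 2.0],
    ///         vec![0.0, 0.0, 3.0],
    ///         vec![4.0, 5.0, 0.0],
    ///     ]
    /// );
    /// ```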
    pub fn to_dense(&self) -> Vec<Vec<T>>
    where
        T: Zero + Copy,
    {
        let mut result = vec![vec![T::zero(); self.cols]; self.rows];

        for (row_idx, row) in result.iter_mut().enumerate() {
            for j in self.indptr[row_idx]..self.indptr[row_idx + 1] {
                let col_idx = self.indices[j];
                row[col_idx] = self.data[j];
            }
        }

        result
    }

    /// Transpose the matrix
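    ///
    /// # Examples
    ///
    /// A small sketch; transposing swaps the roles of rows and columns:
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let matrix = CsrMatrix::new(
    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0],
    ///     vec![0, 0, 1, 2, 2],
    ///     vec![0, 2, 2, 0, 1],
    ///     (3, 3),
    /// )
    /// .unwrap();
    ///
    /// let transposed = matrix.transpose();
    /// assert_eq!(transposed.shape(), (3, 3));
    /// assert_eq!(transposed.get(2, 0), 2.0); // was (0, 2) in the original
    /// ```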
    pub fn transpose(&self) -> Self {
        // Compute the number of non-zeros per column
        let mut col_counts = vec![0; self.cols];
        for &col in &self.indices {
            col_counts[col] += 1;
        }

        // Compute column pointers (cumulative sum)
        let mut col_ptrs = vec![0; self.cols + 1];
        for i in 0..self.cols {
            col_ptrs[i + 1] = col_ptrs[i] + col_counts[i];
        }

        // Fill the transposed matrix
        let nnz = self.nnz();
        let mut indices_t = vec![0; nnz];
        let mut data_t = vec![T::zero(); nnz];
        let mut col_counts = vec![0; self.cols];

        for row in 0..self.rows {
            for j in self.indptr[row]..self.indptr[row + 1] {
                let col = self.indices[j];
                let dest = col_ptrs[col] + col_counts[col];

                indices_t[dest] = row;
                data_t[dest] = self.data[j];
                col_counts[col] += 1;
            }
        }

        CsrMatrix {
            rows: self.cols,
            cols: self.rows,
            indptr: col_ptrs,
            indices: indices_t,
            data: data_t,
        }
    }
}

impl<
        T: Clone
            + Copy
            + std::ops::AddAssign
            + std::ops::MulAssign
            + std::cmp::PartialEq
            + std::fmt::Debug
            + scirs2_core::numeric::Zero
            + std::ops::Add<Output = T>
            + std::ops::Mul<Output = T>,
    > CsrMatrix<T>
{
    /// Check if matrix is symmetric
    ///
    /// # Returns
    ///
    /// * `true` if the matrix is symmetric, `false` otherwise
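    ///
    /// # Examples
    ///
    /// A brief sketch with 2x2 matrices:
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// // [1 2]
    /// // [2 1]
    /// let symmetric = CsrMatrix::new(
    ///     vec![1.0, 2.0, 2.0, 1.0],
    ///     vec![0, 0, 1, 1],
    ///     vec![0, 1, 0, 1],
    ///     (2, 2),
    /// )
    /// .unwrap();
    /// assert!(symmetric.is_symmetric());
    ///
    /// // [1 2]
    /// // [0 1]
    /// let upper = CsrMatrix::new(vec![1.0, 2.0, 1.0], vec![0, 0, 1], vec![0, 1, 1], (2, 2)).unwrap();
    /// assert!(!upper.is_symmetric());
    /// ```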
    pub fn is_symmetric(&self) -> bool {
        if self.rows != self.cols {
            return false;
        }

        // Create a transposed matrix
        let transposed = self.transpose();

        // Compare the sparsity patterns and values
        if self.nnz() != transposed.nnz() {
            return false;
        }

        // Compare row by row
        for row in 0..self.rows {
            let self_start = self.indptr[row];
            let self_end = self.indptr[row + 1];
            let trans_start = transposed.indptr[row];
            let trans_end = transposed.indptr[row + 1];

            if self_end - self_start != trans_end - trans_start {
                return false;
            }

            // Create sorted columns and values for this row
            let mut self_entries: Vec<(usize, &T)> = (self_start..self_end)
                .map(|j| (self.indices[j], &self.data[j]))
                .collect();
            self_entries.sort_by_key(|(col_, _)| *col_);

            let mut trans_entries: Vec<(usize, &T)> = (trans_start..trans_end)
                .map(|j| (transposed.indices[j], &transposed.data[j]))
                .collect();
            trans_entries.sort_by_key(|(col_, _)| *col_);

            // Compare columns and values
            for i in 0..self_entries.len() {
                if self_entries[i].0 != trans_entries[i].0
                    || self_entries[i].1 != trans_entries[i].1
                {
                    return false;
                }
            }
        }

        true
    }

    /// Matrix-matrix multiplication
    ///
    /// # Arguments
    ///
    /// * `other` - Matrix to multiply with
    ///
    /// # Returns
    ///
    /// * Result containing the product matrix
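    ///
    /// # Examples
    ///
    /// A small 2x2 sketch: `[[1, 2], [0, 3]] * [[4, 0], [0, 5]] = [[4, 10], [0, 15]]`.
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let a = CsrMatrix::new(vec![1.0, 2.0, 3.0], vec![0, 0, 1], vec![0, 1, 1], (2, 2)).unwrap();
    /// let b = CsrMatrix::new(vec![4.0, 5.0], vec![0, 1], vec![0, 1], (2, 2)).unwrap();
    ///
    /// let c = a.matmul(&b).unwrap();
    /// assert_eq!(c.to_dense(), vec![vec![4.0, 10.0], vec![0.0, 15.0]]);
    /// ```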
    pub fn matmul(&self, other: &CsrMatrix<T>) -> SparseResult<CsrMatrix<T>> {
        if self.cols != other.rows {
            return Err(SparseError::DimensionMismatch {
                expected: self.cols,
                found: other.rows,
            });
        }

        // For simplicity, we'll implement this using dense operations
        // In a real implementation, you'd use a more efficient sparse algorithm
        let a_dense = self.to_dense();
        let b_dense = other.to_dense();

        let m = self.rows;
        let n = other.cols;
        let k = self.cols;

        let mut c_dense = vec![vec![T::zero(); n]; m];

        for (i, c_row) in c_dense.iter_mut().enumerate().take(m) {
            for (j, val) in c_row.iter_mut().enumerate().take(n) {
                for (l, &a_val) in a_dense[i].iter().enumerate().take(k) {
                    let prod = a_val * b_dense[l][j];
                    *val += prod;
                }
            }
        }

        // Convert back to CSR format
        let mut rowindices = Vec::new();
        let mut colindices = Vec::new();
        let mut values = Vec::new();

        for (i, row) in c_dense.iter().enumerate() {
            for (j, val) in row.iter().enumerate() {
                if *val != T::zero() {
                    rowindices.push(i);
                    colindices.push(j);
                    values.push(*val);
                }
            }
        }

        CsrMatrix::new(values, rowindices, colindices, (m, n))
    }

    /// Get row range for iterating over elements in a row
    ///
    /// # Arguments
    ///
    /// * `row` - Row index
    ///
    /// # Returns
    ///
    /// * Range of indices in the data and indices arrays for this row
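    ///
    /// # Panics
    ///
    /// Panics if `row >= self.rows()`.
    ///
    /// # Examples
    ///
    /// A minimal sketch iterating over the stored entries of one row:
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let matrix = CsrMatrix::new(
    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0],
    ///     vec![0, 0, 1, 2, 2],
    ///     vec![0, 2, 2, 0, 1],
    ///     (3, 3),
    /// )
    /// .unwrap();
    ///
    /// let cols_in_row_2: Vec<usize> = matrix
    ///     .row_range(2)
    ///     .map(|j| matrix.colindices()[j])
    ///     .collect();
    /// assert_eq!(cols_in_row_2, vec![0, 1]);
    /// ```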
    pub fn row_range(&self, row: usize) -> std::ops::Range<usize> {
        assert!(row < self.rows, "Row index out of bounds");
        self.indptr[row]..self.indptr[row + 1]
    }

    /// Get column indices array
    pub fn colindices(&self) -> &[usize] {
        &self.indices
    }
}

impl CsrMatrix<f64> {
    /// Matrix-vector multiplication
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply with
    ///
    /// # Returns
    ///
    /// * Result of matrix-vector multiplication
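    ///
    /// # Examples
    ///
    /// A minimal example (same matrix and vector as the tests below):
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let matrix = CsrMatrix::new(
    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0],
    ///     vec![0, 0, 1, 2, 2],
    ///     vec![0, 2, 2, 0, 1],
    ///     (3, 3),
    /// )
    /// .unwrap();
    ///
    /// let result = matrix.dot(&[1.0, 2.0, 3.0]).unwrap();
    /// assert_eq!(result, vec![7.0, 9.0, 14.0]);
    /// ```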
    pub fn dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
        if vec.len() != self.cols {
            return Err(SparseError::DimensionMismatch {
                expected: self.cols,
                found: vec.len(),
            });
        }

        let mut result = vec![0.0; self.rows];

        for (row_idx, result_val) in result.iter_mut().enumerate() {
            for j in self.indptr[row_idx]..self.indptr[row_idx + 1] {
                let col_idx = self.indices[j];
                *result_val += self.data[j] * vec[col_idx];
            }
        }

        Ok(result)
    }

    /// GPU-accelerated matrix-vector multiplication
    ///
    /// This method automatically uses GPU acceleration when beneficial,
    /// falling back to an optimized CPU implementation when appropriate.
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply with
    ///
    /// # Returns
    ///
    /// * Result of matrix-vector multiplication
    ///
    /// # Examples
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let rows = vec![0, 0, 1, 2, 2];
    /// let cols = vec![0, 2, 2, 0, 1];
    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    /// let shape = (3, 3);
    ///
    /// let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();
    /// let vec = vec![1.0, 2.0, 3.0];
    /// let result = matrix.gpu_dot(&vec).unwrap();
    /// ```
    #[allow(dead_code)]
    pub fn gpu_dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
        // Use the GpuSpMV implementation
        let gpu_spmv = crate::gpu_spmv_implementation::GpuSpMV::new()?;
        gpu_spmv.spmv(
            self.rows,
            self.cols,
            &self.indptr,
            &self.indices,
            &self.data,
            vec,
        )
    }

    /// GPU-accelerated matrix-vector multiplication with backend selection
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply with
    /// * `backend` - Preferred GPU backend
    ///
    /// # Returns
    ///
    /// * Result of matrix-vector multiplication
    #[allow(dead_code)]
    pub fn gpu_dot_with_backend(
        &self,
        vec: &[f64],
        backend: scirs2_core::gpu::GpuBackend,
    ) -> SparseResult<Vec<f64>> {
        // Use the GpuSpMV implementation with specified backend
        let gpu_spmv = crate::gpu_spmv_implementation::GpuSpMV::with_backend(backend)?;
        gpu_spmv.spmv(
            self.rows,
            self.cols,
            &self.indptr,
            &self.indices,
            &self.data,
            vec,
        )
    }
}

impl<T> CsrMatrix<T>
where
    T: scirs2_core::numeric::Float
        + std::fmt::Debug
        + Copy
        + Default
        + GpuDataType
        + Send
        + Sync
        + 'static,
{
    /// GPU-accelerated matrix-vector multiplication for generic floating-point types
    ///
    /// Currently the product is computed on the CPU; GPU operations fall back
    /// to the CPU implementation for stability.
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply with
    ///
    /// # Returns
    ///
    /// * Result of matrix-vector multiplication
    #[allow(dead_code)]
    pub fn gpu_dot_generic(&self, vec: &[T]) -> SparseResult<Vec<T>>
    where
        T: Float + std::ops::AddAssign + Copy + Default + std::iter::Sum,
    {
        // GPU operations fall back to CPU for stability
        if vec.len() != self.cols {
            return Err(SparseError::DimensionMismatch {
                expected: self.cols,
                found: vec.len(),
            });
        }

        let mut result = vec![T::zero(); self.rows];

        for (row_idx, result_val) in result.iter_mut().enumerate() {
            let start = self.indptr[row_idx];
            let end = self.indptr[row_idx + 1];

            for idx in start..end {
                let col = self.indices[idx];
                *result_val += self.data[idx] * vec[col];
            }
        }

        Ok(result)
    }

    /// Check if this matrix should benefit from GPU acceleration
    ///
    /// # Returns
    ///
    /// * `true` if GPU acceleration is likely to provide benefits
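    ///
    /// # Examples
    ///
    /// A small sketch; tiny matrices stay on the CPU:
    ///
    /// ```
    /// use scirs2_sparse::csr::CsrMatrix;
    ///
    /// let small = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 2)).unwrap();
    /// assert!(!small.should_use_gpu());
    /// ```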
    pub fn should_use_gpu(&self) -> bool {
        // Use GPU for matrices with significant computation (> 10k non-zeros)
        // and reasonable sparsity (< 50% dense)
        let nnz_threshold = 10000;
        let density = self.nnz() as f64 / (self.rows * self.cols) as f64;

        self.nnz() > nnz_threshold && density < 0.5
    }

    /// Get GPU backend information
    ///
    /// # Returns
    ///
    /// * Information about available GPU backends
    #[allow(dead_code)]
    pub fn gpu_backend_info() -> SparseResult<(crate::gpu_ops::GpuBackend, String)> {
        // GPU operations fall back to CPU for stability
        Ok((crate::gpu_ops::GpuBackend::Cpu, "CPU Fallback".to_string()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_csr_create() {
        // Create a 3x3 sparse matrix with 5 non-zero elements
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();

        assert_eq!(matrix.shape(), (3, 3));
        assert_eq!(matrix.nnz(), 5);
    }

    #[test]
    fn test_csr_to_dense() {
        // Create a 3x3 sparse matrix with 5 non-zero elements
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();
        let dense = matrix.to_dense();

        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_csr_dot() {
        // Create a 3x3 sparse matrix with 5 non-zero elements
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();

        // Matrix:
        // [1 0 2]
        // [0 0 3]
        // [4 5 0]

        let vec = vec![1.0, 2.0, 3.0];
        let result = matrix.dot(&vec).unwrap();

        // Expected:
        // 1*1 + 0*2 + 2*3 = 7
        // 0*1 + 0*2 + 3*3 = 9
        // 4*1 + 5*2 + 0*3 = 14
        let expected = [7.0, 9.0, 14.0];

        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_csr_transpose() {
        // Create a 3x3 sparse matrix with 5 non-zero elements
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();
        let transposed = matrix.transpose();

        assert_eq!(transposed.shape(), (3, 3));
        assert_eq!(transposed.nnz(), 5);

        let dense = transposed.to_dense();
        let expected = vec![
            vec![1.0, 0.0, 4.0],
            vec![0.0, 0.0, 5.0],
            vec![2.0, 3.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_gpu_dot() {
        // Create a 3x3 sparse matrix
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();
        let vec = vec![1.0, 2.0, 3.0];

        // Test GPU-accelerated SpMV (skip gracefully if GPU is unavailable)
        match matrix.gpu_dot(&vec) {
            Ok(result) => {
                let expected = [7.0, 9.0, 14.0];
                assert_eq!(result.len(), expected.len());
                for (a, b) in result.iter().zip(expected.iter()) {
                    assert_relative_eq!(a, b, epsilon = 1e-10);
                }
            }
            Err(crate::error::SparseError::ComputationError(_))
            | Err(crate::error::SparseError::OperationNotSupported(_)) => {
                // Acceptable when no GPU is available in CI/local machines
            }
            Err(e) => panic!("Unexpected error in GPU SpMV: {:?}", e),
        }
    }

    #[test]
    fn test_should_use_gpu() {
        // Small matrix - should not use GPU
        let small_matrix = CsrMatrix::new(vec![1.0, 2.0], vec![0, 1], vec![0, 1], (2, 2)).unwrap();
        assert!(
            !small_matrix.should_use_gpu(),
            "Small matrix should not use GPU"
        );

        // Large sparse matrix - should use GPU
        let large_data = vec![1.0; 15000];
        let large_rows: Vec<usize> = (0..15000).collect();
        let large_cols: Vec<usize> = (0..15000).collect();
        let large_matrix =
            CsrMatrix::new(large_data, large_rows, large_cols, (15000, 15000)).unwrap();
        assert!(
            large_matrix.should_use_gpu(),
            "Large sparse matrix should use GPU"
        );
    }

    #[test]
    fn test_gpu_backend_info() {
        let backend_info = CsrMatrix::<f64>::gpu_backend_info();
        assert!(
            backend_info.is_ok(),
            "Should be able to get GPU backend info"
        );

        if let Ok((backend, name)) = backend_info {
            assert!(!name.is_empty(), "Backend name should not be empty");
            // Backend should be one of the supported types
            match backend {
                crate::gpu_ops::GpuBackend::Cuda
                | crate::gpu_ops::GpuBackend::OpenCL
                | crate::gpu_ops::GpuBackend::Metal
                | crate::gpu_ops::GpuBackend::Cpu
                | crate::gpu_ops::GpuBackend::Rocm
                | crate::gpu_ops::GpuBackend::Wgpu => {}
            }
        }
    }

    #[test]
    fn test_gpu_dot_generic_f32() {
        // Test with f32 type
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let shape = (3, 3);

        let matrix = CsrMatrix::new(data, rows, cols, shape).unwrap();
        let vec = vec![1.0f32, 2.0, 3.0];

        match matrix.gpu_dot_generic(&vec) {
            Ok(result) => {
                let expected = [7.0f32, 9.0, 14.0];
                assert_eq!(result.len(), expected.len());
                for (a, b) in result.iter().zip(expected.iter()) {
                    assert_relative_eq!(a, b, epsilon = 1e-6);
                }
            }
            Err(crate::error::SparseError::ComputationError(_))
            | Err(crate::error::SparseError::OperationNotSupported(_)) => {}
            Err(e) => panic!("Unexpected error in generic GPU SpMV: {:?}", e),
        }
    }
}