Skip to main content

scivex_core/linalg/
sparse.rs

//! Sparse matrix formats: COO, CSR, CSC.
//!
//! Three standard sparse representations with conversions between them and
//! interoperation with dense [`Tensor`] for I/O and matrix-vector products.

6use crate::Scalar;
7use crate::error::{CoreError, Result};
8use crate::tensor::Tensor;
9
// ======================================================================
// COO (Coordinate) format
// ======================================================================

14/// Sparse matrix in COO (coordinate / triplet) format.
15///
16/// Stores (row, col, value) triplets. Duplicate entries are summed during
17/// conversion to CSR/CSC.
18///
19/// # Examples
20///
21/// ```
22/// # use scivex_core::linalg::sparse::CooMatrix;
23/// let mut coo = CooMatrix::<f64>::new(3, 3);
24/// coo.push(0, 0, 1.0).unwrap();
25/// coo.push(1, 1, 2.0).unwrap();
26/// assert_eq!(coo.nnz(), 2);
27/// ```
28#[cfg_attr(
29    feature = "serde-support",
30    derive(serde::Serialize, serde::Deserialize)
31)]
32#[derive(Debug, Clone)]
33pub struct CooMatrix<T: Scalar> {
34    rows: Vec<usize>,
35    cols: Vec<usize>,
36    values: Vec<T>,
37    nrows: usize,
38    ncols: usize,
39}
40
41impl<T: Scalar> CooMatrix<T> {
42    /// Create an empty COO matrix with the given dimensions.
43    ///
44    /// # Examples
45    ///
46    /// ```
47    /// # use scivex_core::linalg::sparse::CooMatrix;
48    /// let coo = CooMatrix::<f64>::new(3, 3);
49    /// assert_eq!(coo.nnz(), 0);
50    /// assert_eq!(coo.shape(), (3, 3));
51    /// ```
52    pub fn new(nrows: usize, ncols: usize) -> Self {
53        Self {
54            rows: Vec::new(),
55            cols: Vec::new(),
56            values: Vec::new(),
57            nrows,
58            ncols,
59        }
60    }
61
62    /// Build a COO matrix from triplet arrays.
63    ///
64    /// # Examples
65    ///
66    /// ```
67    /// # use scivex_core::linalg::sparse::CooMatrix;
68    /// let coo = CooMatrix::from_triplets(
69    ///     2, 2,
70    ///     vec![0, 1], vec![0, 1], vec![1.0, 2.0],
71    /// ).unwrap();
72    /// assert_eq!(coo.nnz(), 2);
73    /// ```
74    pub fn from_triplets(
75        nrows: usize,
76        ncols: usize,
77        rows: Vec<usize>,
78        cols: Vec<usize>,
79        values: Vec<T>,
80    ) -> Result<Self> {
81        if rows.len() != cols.len() || rows.len() != values.len() {
82            return Err(CoreError::InvalidArgument {
83                reason: "rows, cols, and values must have the same length",
84            });
85        }
86        for (&r, &c) in rows.iter().zip(cols.iter()) {
87            if r >= nrows || c >= ncols {
88                return Err(CoreError::InvalidArgument {
89                    reason: "index out of bounds for matrix dimensions",
90                });
91            }
92        }
93        Ok(Self {
94            rows,
95            cols,
96            values,
97            nrows,
98            ncols,
99        })
100    }
101
102    /// Append a single entry.
103    ///
104    /// # Examples
105    ///
106    /// ```
107    /// # use scivex_core::linalg::sparse::CooMatrix;
108    /// let mut coo = CooMatrix::<f64>::new(2, 2);
109    /// coo.push(0, 1, 3.5).unwrap();
110    /// assert_eq!(coo.nnz(), 1);
111    /// ```
112    pub fn push(&mut self, row: usize, col: usize, value: T) -> Result<()> {
113        if row >= self.nrows || col >= self.ncols {
114            return Err(CoreError::InvalidArgument {
115                reason: "index out of bounds for matrix dimensions",
116            });
117        }
118        self.rows.push(row);
119        self.cols.push(col);
120        self.values.push(value);
121        Ok(())
122    }
123
124    /// Number of rows.
125    #[inline]
126    pub fn nrows(&self) -> usize {
127        self.nrows
128    }
129
130    /// Number of columns.
131    #[inline]
132    pub fn ncols(&self) -> usize {
133        self.ncols
134    }
135
136    /// Number of stored entries (may include duplicates).
137    #[inline]
138    pub fn nnz(&self) -> usize {
139        self.values.len()
140    }
141
142    /// Shape as `(nrows, ncols)`.
143    #[inline]
144    pub fn shape(&self) -> (usize, usize) {
145        (self.nrows, self.ncols)
146    }
147
148    /// Convert to a dense 2-D tensor. Duplicate entries are summed.
149    ///
150    /// # Examples
151    ///
152    /// ```
153    /// # use scivex_core::linalg::sparse::CooMatrix;
154    /// let coo = CooMatrix::from_triplets(
155    ///     2, 2,
156    ///     vec![0, 1], vec![0, 1], vec![1.0, 2.0],
157    /// ).unwrap();
158    /// let dense = coo.to_dense();
159    /// assert_eq!(dense.shape(), &[2, 2]);
160    /// assert_eq!(dense.as_slice(), &[1.0, 0.0, 0.0, 2.0]);
161    /// ```
162    pub fn to_dense(&self) -> Tensor<T> {
163        let mut data = vec![T::zero(); self.nrows * self.ncols];
164        for ((&r, &c), &v) in self
165            .rows
166            .iter()
167            .zip(self.cols.iter())
168            .zip(self.values.iter())
169        {
170            data[r * self.ncols + c] += v;
171        }
172        // SAFETY: data has exactly nrows*ncols elements, matching shape [nrows, ncols].
173        Tensor::from_vec(data, vec![self.nrows, self.ncols])
174            .expect("dense data length equals nrows*ncols by construction")
175    }
176
177    /// Convert to CSR format. Duplicate entries at the same position are summed.
178    ///
179    /// # Examples
180    ///
181    /// ```
182    /// # use scivex_core::linalg::sparse::CooMatrix;
183    /// let coo = CooMatrix::from_triplets(
184    ///     2, 2,
185    ///     vec![0, 1], vec![0, 1], vec![1.0, 2.0],
186    /// ).unwrap();
187    /// let csr = coo.to_csr();
188    /// assert_eq!(csr.nnz(), 2);
189    /// ```
190    pub fn to_csr(&self) -> CsrMatrix<T> {
191        // Count entries per row
192        let mut row_counts = vec![0usize; self.nrows + 1];
193        for &r in &self.rows {
194            row_counts[r + 1] += 1;
195        }
196        // Prefix sum -> row_ptr
197        for i in 1..=self.nrows {
198            row_counts[i] += row_counts[i - 1];
199        }
200
201        let nnz = self.values.len();
202        let mut col_idx = vec![0usize; nnz];
203        let mut values = vec![T::zero(); nnz];
204        let mut offset = row_counts.clone();
205
206        for ((&r, &c), &v) in self
207            .rows
208            .iter()
209            .zip(self.cols.iter())
210            .zip(self.values.iter())
211        {
212            let pos = offset[r];
213            col_idx[pos] = c;
214            values[pos] = v;
215            offset[r] += 1;
216        }
217
218        // Sort within each row by column index and sum duplicates
219        let mut result = CsrMatrix {
220            row_ptr: row_counts,
221            col_idx,
222            values,
223            nrows: self.nrows,
224            ncols: self.ncols,
225        };
226        result.sort_and_sum_duplicates();
227        result
228    }
229
230    /// Convert to CSC format. Duplicate entries at the same position are summed.
231    ///
232    /// # Examples
233    ///
234    /// ```
235    /// # use scivex_core::linalg::sparse::CooMatrix;
236    /// let coo = CooMatrix::from_triplets(
237    ///     2, 2,
238    ///     vec![0, 1], vec![0, 1], vec![1.0, 2.0],
239    /// ).unwrap();
240    /// let csc = coo.to_csc();
241    /// assert_eq!(csc.nnz(), 2);
242    /// ```
243    pub fn to_csc(&self) -> CscMatrix<T> {
244        // Count entries per column
245        let mut col_counts = vec![0usize; self.ncols + 1];
246        for &c in &self.cols {
247            col_counts[c + 1] += 1;
248        }
249        for i in 1..=self.ncols {
250            col_counts[i] += col_counts[i - 1];
251        }
252
253        let nnz = self.values.len();
254        let mut row_idx = vec![0usize; nnz];
255        let mut values = vec![T::zero(); nnz];
256        let mut offset = col_counts.clone();
257
258        for ((&r, &c), &v) in self
259            .rows
260            .iter()
261            .zip(self.cols.iter())
262            .zip(self.values.iter())
263        {
264            let pos = offset[c];
265            row_idx[pos] = r;
266            values[pos] = v;
267            offset[c] += 1;
268        }
269
270        let mut result = CscMatrix {
271            col_ptr: col_counts,
272            row_idx,
273            values,
274            nrows: self.nrows,
275            ncols: self.ncols,
276        };
277        result.sort_and_sum_duplicates();
278        result
279    }
280}
281
// ======================================================================
// CSR (Compressed Sparse Row) format
// ======================================================================

286/// Sparse matrix in CSR (Compressed Sparse Row) format.
287///
288/// # Examples
289///
290/// ```
291/// # use scivex_core::linalg::sparse::CsrMatrix;
292/// let csr = CsrMatrix::from_triplets(
293///     2, 2,
294///     vec![0, 1], vec![0, 1], vec![1.0, 2.0],
295/// ).unwrap();
296/// assert_eq!(csr.nnz(), 2);
297/// assert_eq!(*csr.get(0, 0).unwrap(), 1.0);
298/// ```
299#[cfg_attr(
300    feature = "serde-support",
301    derive(serde::Serialize, serde::Deserialize)
302)]
303#[derive(Debug, Clone)]
304pub struct CsrMatrix<T: Scalar> {
305    row_ptr: Vec<usize>,
306    col_idx: Vec<usize>,
307    values: Vec<T>,
308    nrows: usize,
309    ncols: usize,
310}
311
312impl<T: Scalar> CsrMatrix<T> {
313    /// Create an empty CSR matrix.
314    ///
315    /// # Examples
316    ///
317    /// ```
318    /// # use scivex_core::linalg::sparse::CsrMatrix;
319    /// let csr = CsrMatrix::<f64>::new(3, 3);
320    /// assert_eq!(csr.nnz(), 0);
321    /// ```
322    pub fn new(nrows: usize, ncols: usize) -> Self {
323        Self {
324            row_ptr: vec![0; nrows + 1],
325            col_idx: Vec::new(),
326            values: Vec::new(),
327            nrows,
328            ncols,
329        }
330    }
331
332    /// Build CSR from triplet data.
333    ///
334    /// # Examples
335    ///
336    /// ```
337    /// # use scivex_core::linalg::sparse::CsrMatrix;
338    /// let csr = CsrMatrix::from_triplets(
339    ///     2, 2,
340    ///     vec![0, 1], vec![1, 0], vec![3.0, 4.0],
341    /// ).unwrap();
342    /// assert_eq!(csr.nnz(), 2);
343    /// ```
344    pub fn from_triplets(
345        nrows: usize,
346        ncols: usize,
347        rows: Vec<usize>,
348        cols: Vec<usize>,
349        values: Vec<T>,
350    ) -> Result<Self> {
351        let coo = CooMatrix::from_triplets(nrows, ncols, rows, cols, values)?;
352        Ok(coo.to_csr())
353    }
354
355    /// Build CSR from a dense 2-D tensor, dropping zero entries.
356    ///
357    /// # Examples
358    ///
359    /// ```
360    /// # use scivex_core::tensor::Tensor;
361    /// # use scivex_core::linalg::sparse::CsrMatrix;
362    /// let dense = Tensor::from_vec(vec![1.0, 0.0, 0.0, 2.0], vec![2, 2]).unwrap();
363    /// let csr = CsrMatrix::from_dense(&dense).unwrap();
364    /// assert_eq!(csr.nnz(), 2);
365    /// ```
366    pub fn from_dense(tensor: &Tensor<T>) -> Result<Self> {
367        if tensor.ndim() != 2 {
368            return Err(CoreError::InvalidArgument {
369                reason: "from_dense requires a 2-D tensor",
370            });
371        }
372        let nrows = tensor.shape()[0];
373        let ncols = tensor.shape()[1];
374        let data = tensor.as_slice();
375
376        let mut row_ptr = vec![0usize; nrows + 1];
377        let mut col_idx = Vec::new();
378        let mut values = Vec::new();
379
380        for r in 0..nrows {
381            for c in 0..ncols {
382                let v = data[r * ncols + c];
383                if v != T::zero() {
384                    col_idx.push(c);
385                    values.push(v);
386                }
387            }
388            row_ptr[r + 1] = values.len();
389        }
390
391        Ok(Self {
392            row_ptr,
393            col_idx,
394            values,
395            nrows,
396            ncols,
397        })
398    }
399
400    /// Number of rows.
401    #[inline]
402    pub fn nrows(&self) -> usize {
403        self.nrows
404    }
405
406    /// Number of columns.
407    #[inline]
408    pub fn ncols(&self) -> usize {
409        self.ncols
410    }
411
412    /// Number of stored non-zero entries.
413    #[inline]
414    pub fn nnz(&self) -> usize {
415        self.values.len()
416    }
417
418    /// Shape as `(nrows, ncols)`.
419    #[inline]
420    pub fn shape(&self) -> (usize, usize) {
421        (self.nrows, self.ncols)
422    }
423
424    /// Get the value at `(row, col)`, or `None` if not stored.
425    ///
426    /// # Examples
427    ///
428    /// ```
429    /// # use scivex_core::linalg::sparse::CsrMatrix;
430    /// let csr = CsrMatrix::from_triplets(
431    ///     2, 2, vec![0], vec![1], vec![5.0],
432    /// ).unwrap();
433    /// assert_eq!(*csr.get(0, 1).unwrap(), 5.0);
434    /// assert!(csr.get(0, 0).is_none());
435    /// ```
436    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
437        if row >= self.nrows || col >= self.ncols {
438            return None;
439        }
440        let start = self.row_ptr[row];
441        let end = self.row_ptr[row + 1];
442        self.col_idx[start..end]
443            .binary_search(&col)
444            .ok()
445            .map(|pos| &self.values[start + pos])
446    }
447
448    /// Convert to a dense 2-D tensor.
449    ///
450    /// # Examples
451    ///
452    /// ```
453    /// # use scivex_core::linalg::sparse::CsrMatrix;
454    /// let csr = CsrMatrix::from_triplets(
455    ///     2, 2, vec![0, 1], vec![0, 1], vec![1.0, 2.0],
456    /// ).unwrap();
457    /// let dense = csr.to_dense();
458    /// assert_eq!(dense.as_slice(), &[1.0, 0.0, 0.0, 2.0]);
459    /// ```
460    pub fn to_dense(&self) -> Tensor<T> {
461        let mut data = vec![T::zero(); self.nrows * self.ncols];
462        for r in 0..self.nrows {
463            let start = self.row_ptr[r];
464            let end = self.row_ptr[r + 1];
465            for idx in start..end {
466                let c = self.col_idx[idx];
467                data[r * self.ncols + c] = self.values[idx];
468            }
469        }
470        // SAFETY: data has exactly nrows*ncols elements, matching shape [nrows, ncols].
471        Tensor::from_vec(data, vec![self.nrows, self.ncols])
472            .expect("dense data length equals nrows*ncols by construction")
473    }
474
475    /// Sparse matrix x dense vector multiplication.
476    ///
477    /// `x` must be a 1-D tensor of length `ncols`.
478    ///
479    /// # Examples
480    ///
481    /// ```
482    /// # use scivex_core::tensor::Tensor;
483    /// # use scivex_core::linalg::sparse::CsrMatrix;
484    /// let csr = CsrMatrix::from_triplets(
485    ///     2, 2, vec![0, 1], vec![0, 1], vec![3.0, 4.0],
486    /// ).unwrap();
487    /// let x = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
488    /// let y = csr.matvec(&x).unwrap();
489    /// assert_eq!(y.as_slice(), &[3.0, 8.0]);
490    /// ```
491    pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
492        if x.ndim() != 1 || x.numel() != self.ncols {
493            return Err(CoreError::DimensionMismatch {
494                expected: vec![self.ncols],
495                got: x.shape().to_vec(),
496            });
497        }
498        let xdata = x.as_slice();
499        let mut result = vec![T::zero(); self.nrows];
500
501        for (r, dest) in result.iter_mut().enumerate() {
502            let start = self.row_ptr[r];
503            let end = self.row_ptr[r + 1];
504            let mut acc = T::zero();
505            for idx in start..end {
506                acc += self.values[idx] * xdata[self.col_idx[idx]];
507            }
508            *dest = acc;
509        }
510
511        Tensor::from_vec(result, vec![self.nrows])
512    }
513
514    /// Transpose, returning a CSC matrix.
515    ///
516    /// # Examples
517    ///
518    /// ```
519    /// # use scivex_core::linalg::sparse::CsrMatrix;
520    /// let csr = CsrMatrix::from_triplets(
521    ///     2, 3, vec![0, 1], vec![1, 2], vec![1.0, 2.0],
522    /// ).unwrap();
523    /// let csc = csr.transpose();
524    /// assert_eq!(csc.nrows(), 3);
525    /// assert_eq!(csc.ncols(), 2);
526    /// ```
527    pub fn transpose(&self) -> CscMatrix<T> {
528        CscMatrix {
529            col_ptr: self.row_ptr.clone(),
530            row_idx: self.col_idx.clone(),
531            values: self.values.clone(),
532            nrows: self.ncols,
533            ncols: self.nrows,
534        }
535    }
536
537    /// Convert to COO format.
538    ///
539    /// # Examples
540    ///
541    /// ```
542    /// # use scivex_core::linalg::sparse::CsrMatrix;
543    /// let csr = CsrMatrix::from_triplets(
544    ///     2, 2, vec![0, 1], vec![0, 1], vec![1.0, 2.0],
545    /// ).unwrap();
546    /// let coo = csr.to_coo();
547    /// assert_eq!(coo.nnz(), 2);
548    /// ```
549    pub fn to_coo(&self) -> CooMatrix<T> {
550        let mut rows = Vec::with_capacity(self.nnz());
551        let mut cols = Vec::with_capacity(self.nnz());
552        let mut values = Vec::with_capacity(self.nnz());
553
554        for r in 0..self.nrows {
555            let start = self.row_ptr[r];
556            let end = self.row_ptr[r + 1];
557            for idx in start..end {
558                rows.push(r);
559                cols.push(self.col_idx[idx]);
560                values.push(self.values[idx]);
561            }
562        }
563
564        CooMatrix {
565            rows,
566            cols,
567            values,
568            nrows: self.nrows,
569            ncols: self.ncols,
570        }
571    }
572
573    /// Convert to CSC format.
574    ///
575    /// # Examples
576    ///
577    /// ```
578    /// # use scivex_core::linalg::sparse::CsrMatrix;
579    /// let csr = CsrMatrix::from_triplets(
580    ///     2, 2, vec![0, 1], vec![0, 1], vec![1.0, 2.0],
581    /// ).unwrap();
582    /// let csc = csr.to_csc();
583    /// assert_eq!(csc.nnz(), 2);
584    /// ```
585    pub fn to_csc(&self) -> CscMatrix<T> {
586        self.to_coo().to_csc()
587    }
588
589    /// Sort column indices within each row and sum duplicate entries.
590    fn sort_and_sum_duplicates(&mut self) {
591        for r in 0..self.nrows {
592            let start = self.row_ptr[r];
593            let end = self.row_ptr[r + 1];
594            if start == end {
595                continue;
596            }
597
598            // Sort by column index using a permutation
599            let len = end - start;
600            let mut perm: Vec<usize> = (0..len).collect();
601            perm.sort_unstable_by_key(|&i| self.col_idx[start + i]);
602
603            let old_cols: Vec<usize> = self.col_idx[start..end].to_vec();
604            let old_vals: Vec<T> = self.values[start..end].to_vec();
605            for (j, &p) in perm.iter().enumerate() {
606                self.col_idx[start + j] = old_cols[p];
607                self.values[start + j] = old_vals[p];
608            }
609
610            // Sum duplicates in-place
611            let mut write = start;
612            for read in (start + 1)..end {
613                if self.col_idx[read] == self.col_idx[write] {
614                    let v = self.values[read];
615                    self.values[write] += v;
616                } else {
617                    write += 1;
618                    self.col_idx[write] = self.col_idx[read];
619                    self.values[write] = self.values[read];
620                }
621            }
622            let new_end = write + 1;
623
624            // Shift subsequent data if duplicates were removed
625            if new_end < end {
626                let removed = end - new_end;
627                let total_nnz = self.col_idx.len();
628                self.col_idx.copy_within(end..total_nnz, new_end);
629                self.col_idx.truncate(total_nnz - removed);
630                let total_vals = self.values.len();
631                self.values.copy_within(end..total_vals, new_end);
632                self.values.truncate(total_vals - removed);
633
634                for i in (r + 1)..=self.nrows {
635                    self.row_ptr[i] -= removed;
636                }
637            }
638        }
639    }
640}
641
// ======================================================================
// CSC (Compressed Sparse Column) format
// ======================================================================

646/// Sparse matrix in CSC (Compressed Sparse Column) format.
647///
648/// # Examples
649///
650/// ```
651/// # use scivex_core::linalg::sparse::CscMatrix;
652/// let csc = CscMatrix::from_triplets(
653///     2, 2,
654///     vec![0, 1], vec![0, 1], vec![1.0, 2.0],
655/// ).unwrap();
656/// assert_eq!(csc.nnz(), 2);
657/// ```
658#[cfg_attr(
659    feature = "serde-support",
660    derive(serde::Serialize, serde::Deserialize)
661)]
662#[derive(Debug, Clone)]
663pub struct CscMatrix<T: Scalar> {
664    col_ptr: Vec<usize>,
665    row_idx: Vec<usize>,
666    values: Vec<T>,
667    nrows: usize,
668    ncols: usize,
669}
670
671impl<T: Scalar> CscMatrix<T> {
672    /// Create an empty CSC matrix.
673    ///
674    /// # Examples
675    ///
676    /// ```
677    /// # use scivex_core::linalg::sparse::CscMatrix;
678    /// let csc = CscMatrix::<f64>::new(3, 3);
679    /// assert_eq!(csc.nnz(), 0);
680    /// ```
681    pub fn new(nrows: usize, ncols: usize) -> Self {
682        Self {
683            col_ptr: vec![0; ncols + 1],
684            row_idx: Vec::new(),
685            values: Vec::new(),
686            nrows,
687            ncols,
688        }
689    }
690
691    /// Build CSC from triplet data.
692    ///
693    /// # Examples
694    ///
695    /// ```
696    /// # use scivex_core::linalg::sparse::CscMatrix;
697    /// let csc = CscMatrix::from_triplets(
698    ///     2, 2,
699    ///     vec![0, 1], vec![0, 1], vec![1.0, 2.0],
700    /// ).unwrap();
701    /// assert_eq!(csc.nnz(), 2);
702    /// ```
703    pub fn from_triplets(
704        nrows: usize,
705        ncols: usize,
706        rows: Vec<usize>,
707        cols: Vec<usize>,
708        values: Vec<T>,
709    ) -> Result<Self> {
710        let coo = CooMatrix::from_triplets(nrows, ncols, rows, cols, values)?;
711        Ok(coo.to_csc())
712    }
713
714    /// Number of rows.
715    #[inline]
716    pub fn nrows(&self) -> usize {
717        self.nrows
718    }
719
720    /// Number of columns.
721    #[inline]
722    pub fn ncols(&self) -> usize {
723        self.ncols
724    }
725
726    /// Number of stored non-zero entries.
727    #[inline]
728    pub fn nnz(&self) -> usize {
729        self.values.len()
730    }
731
732    /// Shape as `(nrows, ncols)`.
733    #[inline]
734    pub fn shape(&self) -> (usize, usize) {
735        (self.nrows, self.ncols)
736    }
737
738    /// Convert to a dense 2-D tensor.
739    ///
740    /// # Examples
741    ///
742    /// ```
743    /// # use scivex_core::linalg::sparse::CscMatrix;
744    /// let csc = CscMatrix::from_triplets(
745    ///     2, 2, vec![0, 1], vec![0, 1], vec![1.0, 2.0],
746    /// ).unwrap();
747    /// let dense = csc.to_dense();
748    /// assert_eq!(dense.as_slice(), &[1.0, 0.0, 0.0, 2.0]);
749    /// ```
750    pub fn to_dense(&self) -> Tensor<T> {
751        let mut data = vec![T::zero(); self.nrows * self.ncols];
752        for c in 0..self.ncols {
753            let start = self.col_ptr[c];
754            let end = self.col_ptr[c + 1];
755            for idx in start..end {
756                let r = self.row_idx[idx];
757                data[r * self.ncols + c] = self.values[idx];
758            }
759        }
760        // SAFETY: data has exactly nrows*ncols elements, matching shape [nrows, ncols].
761        Tensor::from_vec(data, vec![self.nrows, self.ncols])
762            .expect("dense data length equals nrows*ncols by construction")
763    }
764
765    /// Sparse matrix × dense vector multiplication.
766    ///
767    /// `x` must be a 1-D tensor of length `ncols`.
768    pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
769        if x.ndim() != 1 || x.numel() != self.ncols {
770            return Err(CoreError::DimensionMismatch {
771                expected: vec![self.ncols],
772                got: x.shape().to_vec(),
773            });
774        }
775        let xdata = x.as_slice();
776        let mut result = vec![T::zero(); self.nrows];
777
778        for (c, &xc) in xdata.iter().enumerate().take(self.ncols) {
779            let start = self.col_ptr[c];
780            let end = self.col_ptr[c + 1];
781            for idx in start..end {
782                result[self.row_idx[idx]] += self.values[idx] * xc;
783            }
784        }
785
786        Tensor::from_vec(result, vec![self.nrows])
787    }
788
789    /// Transpose, returning a CSR matrix.
790    pub fn transpose(&self) -> CsrMatrix<T> {
791        CsrMatrix {
792            row_ptr: self.col_ptr.clone(),
793            col_idx: self.row_idx.clone(),
794            values: self.values.clone(),
795            nrows: self.ncols,
796            ncols: self.nrows,
797        }
798    }
799
800    /// Convert to COO format.
801    pub fn to_coo(&self) -> CooMatrix<T> {
802        let mut rows = Vec::with_capacity(self.nnz());
803        let mut cols = Vec::with_capacity(self.nnz());
804        let mut values = Vec::with_capacity(self.nnz());
805
806        for c in 0..self.ncols {
807            let start = self.col_ptr[c];
808            let end = self.col_ptr[c + 1];
809            for idx in start..end {
810                rows.push(self.row_idx[idx]);
811                cols.push(c);
812                values.push(self.values[idx]);
813            }
814        }
815
816        CooMatrix {
817            rows,
818            cols,
819            values,
820            nrows: self.nrows,
821            ncols: self.ncols,
822        }
823    }
824
825    /// Convert to CSR format.
826    pub fn to_csr(&self) -> CsrMatrix<T> {
827        self.to_coo().to_csr()
828    }
829
830    /// Sort row indices within each column and sum duplicate entries.
831    fn sort_and_sum_duplicates(&mut self) {
832        for c in 0..self.ncols {
833            let start = self.col_ptr[c];
834            let end = self.col_ptr[c + 1];
835            if start == end {
836                continue;
837            }
838
839            let len = end - start;
840            let mut perm: Vec<usize> = (0..len).collect();
841            perm.sort_unstable_by_key(|&i| self.row_idx[start + i]);
842
843            let old_rows: Vec<usize> = self.row_idx[start..end].to_vec();
844            let old_vals: Vec<T> = self.values[start..end].to_vec();
845            for (j, &p) in perm.iter().enumerate() {
846                self.row_idx[start + j] = old_rows[p];
847                self.values[start + j] = old_vals[p];
848            }
849
850            // Sum duplicates
851            let mut write = start;
852            for read in (start + 1)..end {
853                if self.row_idx[read] == self.row_idx[write] {
854                    let v = self.values[read];
855                    self.values[write] += v;
856                } else {
857                    write += 1;
858                    self.row_idx[write] = self.row_idx[read];
859                    self.values[write] = self.values[read];
860                }
861            }
862            let new_end = write + 1;
863
864            if new_end < end {
865                let removed = end - new_end;
866                let total_idx = self.row_idx.len();
867                self.row_idx.copy_within(end..total_idx, new_end);
868                self.row_idx.truncate(total_idx - removed);
869                let total_vals = self.values.len();
870                self.values.copy_within(end..total_vals, new_end);
871                self.values.truncate(total_vals - removed);
872
873                for i in (c + 1)..=self.ncols {
874                    self.col_ptr[i] -= removed;
875                }
876            }
877        }
878    }
879}
880
#[cfg(test)]
// Exact float comparisons are intentional: every expected value below is a
// small integer (or a sum of small integers) that f64 represents exactly.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    // Helper: 3x3 matrix
    // [[1, 0, 2],
    //  [0, 3, 0],
    //  [4, 0, 5]]
    fn sample_coo() -> CooMatrix<f64> {
        CooMatrix::from_triplets(
            3,
            3,
            vec![0, 0, 1, 2, 2],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        )
        .unwrap()
    }

    // Helper: 3x3 matrix with unsorted input and repeated positions.
    // (0,1): 2 + 5 = 7, (1,0): 1 + 4 + 6 = 11, (2,2): 3
    // Dense: [[0, 7, 0],
    //         [11, 0, 0],
    //         [0, 0, 3]]
    fn duplicate_coo() -> CooMatrix<f64> {
        CooMatrix::from_triplets(
            3,
            3,
            vec![1, 0, 2, 1, 0, 1],
            vec![0, 1, 2, 0, 1, 0],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        )
        .unwrap()
    }

    #[test]
    fn test_coo_to_dense() {
        let coo = sample_coo();
        let dense = coo.to_dense();
        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(
            dense.as_slice(),
            &[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0]
        );
    }

    #[test]
    fn test_csr_from_dense_roundtrip() {
        let dense = Tensor::from_vec(
            vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0],
            vec![3, 3],
        )
        .unwrap();
        let csr = CsrMatrix::from_dense(&dense).unwrap();
        assert_eq!(csr.nnz(), 5);
        let back = csr.to_dense();
        assert_eq!(dense, back);
    }

    #[test]
    fn test_csr_matvec() {
        let csr = sample_coo().to_csr();
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let y = csr.matvec(&x).unwrap();
        // [1*1 + 0*2 + 2*3, 0*1 + 3*2 + 0*3, 4*1 + 0*2 + 5*3] = [7, 6, 19]
        assert_eq!(y.as_slice(), &[7.0, 6.0, 19.0]);
    }

    #[test]
    fn test_csc_matvec() {
        let csc = sample_coo().to_csc();
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let y = csc.matvec(&x).unwrap();
        assert_eq!(y.as_slice(), &[7.0, 6.0, 19.0]);
    }

    #[test]
    fn test_coo_csr_csc_dense_roundtrip() {
        let coo = sample_coo();
        let expected = coo.to_dense();

        let csr = coo.to_csr();
        assert_eq!(csr.to_dense(), expected);

        let csc = csr.to_csc();
        assert_eq!(csc.to_dense(), expected);

        let coo2 = csc.to_coo();
        assert_eq!(coo2.to_dense(), expected);
    }

    #[test]
    fn test_identity_matrix() {
        let csr = CsrMatrix::from_dense(&Tensor::<f64>::eye(4)).unwrap();
        assert_eq!(csr.nnz(), 4);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = csr.matvec(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_empty_matrix() {
        let csr = CsrMatrix::<f64>::new(3, 3);
        assert_eq!(csr.nnz(), 0);
        let dense = csr.to_dense();
        assert_eq!(dense, Tensor::<f64>::zeros(vec![3, 3]));
    }

    #[test]
    fn test_zero_row_matrix() {
        let coo = CooMatrix::<f64>::new(0, 3);
        let csr = coo.to_csr();
        assert_eq!(csr.nnz(), 0);
        assert_eq!(csr.shape(), (0, 3));
        let csc = coo.to_csc();
        assert_eq!(csc.nnz(), 0);
        assert_eq!(csc.shape(), (0, 3));
    }

    #[test]
    fn test_dimension_mismatch() {
        let csr = sample_coo().to_csr();
        let x = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        assert!(csr.matvec(&x).is_err());
    }

    #[test]
    fn test_duplicate_coo_entries_summed() {
        // Two entries at (0, 0): 1.0 + 2.0 = 3.0
        let coo = CooMatrix::from_triplets(2, 2, vec![0, 0, 1], vec![0, 0, 1], vec![1.0, 2.0, 5.0])
            .unwrap();
        let csr = coo.to_csr();
        assert_eq!(*csr.get(0, 0).unwrap(), 3.0);
        assert_eq!(*csr.get(1, 1).unwrap(), 5.0);
        assert_eq!(csr.nnz(), 2);
    }

    #[test]
    fn test_duplicates_across_rows_and_columns() {
        // Duplicates in more than one row/column exercise storage compaction.
        let coo = duplicate_coo();
        let expected = coo.to_dense();

        let csr = coo.to_csr();
        assert_eq!(csr.nnz(), 3);
        assert_eq!(*csr.get(0, 1).unwrap(), 7.0);
        assert_eq!(*csr.get(1, 0).unwrap(), 11.0);
        assert_eq!(*csr.get(2, 2).unwrap(), 3.0);
        assert_eq!(csr.to_dense(), expected);

        let csc = coo.to_csc();
        assert_eq!(csc.nnz(), 3);
        assert_eq!(csc.to_dense(), expected);

        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        assert_eq!(csr.matvec(&x).unwrap().as_slice(), &[14.0, 11.0, 9.0]);
        assert_eq!(csc.matvec(&x).unwrap().as_slice(), &[14.0, 11.0, 9.0]);
    }

    #[test]
    fn test_unsorted_input_is_sorted() {
        let csr =
            CsrMatrix::from_triplets(1, 3, vec![0, 0, 0], vec![2, 0, 1], vec![3.0, 1.0, 2.0])
                .unwrap();
        assert_eq!(*csr.get(0, 0).unwrap(), 1.0);
        assert_eq!(*csr.get(0, 1).unwrap(), 2.0);
        assert_eq!(*csr.get(0, 2).unwrap(), 3.0);
        assert_eq!(csr.to_dense().as_slice(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_csr_transpose() {
        let csr = sample_coo().to_csr();
        let csc = csr.transpose();
        assert_eq!(csc.nrows(), 3);
        assert_eq!(csc.ncols(), 3);
        // The transposed matrix's dense form should be the transpose of the original
        let orig = csr.to_dense();
        let trans = csc.to_dense();
        // Check (i,j) of trans == (j,i) of orig
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(*trans.get(&[i, j]).unwrap(), *orig.get(&[j, i]).unwrap());
            }
        }
    }

    #[test]
    fn test_csr_transpose_non_square() {
        // 2x3: [[0, 1, 0],
        //       [0, 0, 2]]
        let csr = CsrMatrix::from_triplets(2, 3, vec![0, 1], vec![1, 2], vec![1.0, 2.0]).unwrap();
        let csc = csr.transpose();
        // A non-square shape makes the nrows/ncols swap observable.
        assert_eq!(csc.shape(), (3, 2));
        let orig = csr.to_dense();
        let trans = csc.to_dense();
        for i in 0..3 {
            for j in 0..2 {
                assert_eq!(*trans.get(&[i, j]).unwrap(), *orig.get(&[j, i]).unwrap());
            }
        }
    }

    #[test]
    fn test_csr_get() {
        let csr = sample_coo().to_csr();
        assert_eq!(*csr.get(0, 0).unwrap(), 1.0);
        assert_eq!(*csr.get(0, 2).unwrap(), 2.0);
        assert!(csr.get(0, 1).is_none()); // zero entry
        assert!(csr.get(5, 0).is_none()); // out of bounds
    }

    #[test]
    fn test_coo_push() {
        let mut coo = CooMatrix::<f64>::new(2, 2);
        coo.push(0, 0, 1.0).unwrap();
        coo.push(1, 1, 2.0).unwrap();
        assert_eq!(coo.nnz(), 2);
        assert!(coo.push(2, 0, 1.0).is_err()); // out of bounds
    }
}