// tang_sparse/coo.rs

1use crate::CsrMatrix;
2use alloc::vec::Vec;
3use tang::Scalar;
4
/// Coordinate (triplet) format for assembling sparse matrices.
///
/// Entries are stored as three parallel arrays: the `k`-th triplet is
/// `(rows[k], cols[k], vals[k])`. Triplets may appear in any order and the
/// same `(row, col)` position may appear more than once; duplicates are summed
/// when the matrix is converted to CSR.
///
/// Good for incremental construction, then convert to CSR for computation.
#[derive(Debug, Clone)]
pub struct CooMatrix<S> {
    /// Number of rows; every entry of `rows` is expected to be `< nrows`.
    pub nrows: usize,
    /// Number of columns; every entry of `cols` is expected to be `< ncols`.
    pub ncols: usize,
    /// Row index of each stored triplet.
    pub rows: Vec<usize>,
    /// Column index of each stored triplet (same length as `rows`).
    pub cols: Vec<usize>,
    /// Value of each stored triplet (same length as `rows`).
    pub vals: Vec<S>,
}
15
16impl<S: Scalar> CooMatrix<S> {
17    pub fn new(nrows: usize, ncols: usize) -> Self {
18        Self {
19            nrows,
20            ncols,
21            rows: Vec::new(),
22            cols: Vec::new(),
23            vals: Vec::new(),
24        }
25    }
26
27    pub fn with_capacity(nrows: usize, ncols: usize, nnz: usize) -> Self {
28        Self {
29            nrows,
30            ncols,
31            rows: Vec::with_capacity(nnz),
32            cols: Vec::with_capacity(nnz),
33            vals: Vec::with_capacity(nnz),
34        }
35    }
36
37    /// Add a triplet (row, col, value). Duplicates are summed during conversion.
38    pub fn push(&mut self, row: usize, col: usize, val: S) {
39        assert!(row < self.nrows && col < self.ncols);
40        self.rows.push(row);
41        self.cols.push(col);
42        self.vals.push(val);
43    }
44
45    pub fn nnz(&self) -> usize {
46        self.rows.len()
47    }
48
49    /// Convert to CSR format (summing duplicate entries).
50    pub fn to_csr(&self) -> CsrMatrix<S> {
51        let mut row_counts = alloc::vec![0usize; self.nrows + 1];
52        for &r in &self.rows {
53            row_counts[r + 1] += 1;
54        }
55        // Prefix sum
56        for i in 1..=self.nrows {
57            row_counts[i] += row_counts[i - 1];
58        }
59        let nnz = row_counts[self.nrows];
60        let mut col_indices = alloc::vec![0usize; nnz];
61        let mut values = alloc::vec![S::ZERO; nnz];
62        let mut offsets = row_counts.clone();
63
64        for k in 0..self.rows.len() {
65            let r = self.rows[k];
66            let pos = offsets[r];
67            col_indices[pos] = self.cols[k];
68            values[pos] = self.vals[k];
69            offsets[r] += 1;
70        }
71
72        // Sort each row by column and merge duplicates
73        let row_ptrs = row_counts;
74        for i in 0..self.nrows {
75            let start = row_ptrs[i];
76            let end = row_ptrs[i + 1];
77            // Insertion sort (rows are typically short)
78            for j in (start + 1)..end {
79                let mut k = j;
80                while k > start && col_indices[k] < col_indices[k - 1] {
81                    col_indices.swap(k, k - 1);
82                    values.swap(k, k - 1);
83                    k -= 1;
84                }
85            }
86        }
87
88        // Merge duplicates
89        let mut new_col = Vec::with_capacity(nnz);
90        let mut new_val = Vec::with_capacity(nnz);
91        let mut new_ptrs = alloc::vec![0usize; self.nrows + 1];
92
93        for i in 0..self.nrows {
94            let start = row_ptrs[i];
95            let end = row_ptrs[i + 1];
96            let mut j = start;
97            while j < end {
98                let c = col_indices[j];
99                let mut v = values[j];
100                j += 1;
101                while j < end && col_indices[j] == c {
102                    v += values[j];
103                    j += 1;
104                }
105                new_col.push(c);
106                new_val.push(v);
107            }
108            new_ptrs[i + 1] = new_col.len();
109        }
110
111        CsrMatrix {
112            nrows: self.nrows,
113            ncols: self.ncols,
114            row_ptrs: new_ptrs,
115            col_indices: new_col,
116            values: new_val,
117        }
118    }
119}