Skip to main content

trueno_sparse/
csr.rs

1//! Compressed Sparse Row (CSR) matrix format.
2//!
3//! CSR is the primary format for sparse matrix arithmetic (SpMV, SpMM).
4//! Stores row offsets, column indices, and values in three contiguous arrays.
5//!
6//! # Contract: sparse-spmv-v1.yaml
7//!
8//! Format invariants validated at construction:
9//! - `offsets[0] == 0 && offsets[rows] == nnz`
10//! - `offsets` monotonically non-decreasing
11//! - All column indices in `[0, cols)`
12
13use crate::coo::CooMatrix;
14use crate::error::SparseError;
15use crate::validate::validate_csr_invariants;
16
/// Compressed Sparse Row matrix.
///
/// Three-array representation: `offsets` (len = rows+1), `col_indices` (len = nnz),
/// `values` (len = nnz). Row `i` has nonzeros at positions `offsets[i]..offsets[i+1]`.
///
/// [`CsrMatrix::new`] checks all invariants via `validate_csr_invariants` (provable
/// contract). `from_coo` and `identity` skip that check: they build `offsets` so it is
/// well-formed by construction, and `from_coo` relies on the COO column indices already
/// being in range.
#[derive(Debug, Clone)]
pub struct CsrMatrix<T> {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Row pointers: `offsets[i]..offsets[i + 1]` indexes row `i`'s entries.
    offsets: Vec<u32>,
    /// Column index of each stored entry, parallel to `values`.
    col_indices: Vec<u32>,
    /// Value of each stored entry, parallel to `col_indices`.
    values: Vec<T>,
}
31
32impl<T: Clone + Default> CsrMatrix<T> {
33    /// Create a CSR matrix with full validation.
34    ///
35    /// # Contract: sparse-spmv-v1.yaml / format_validation
36    ///
37    /// Validates all CSR invariants before construction.
38    ///
39    /// # Errors
40    ///
41    /// Returns error if any CSR invariant is violated.
42    pub fn new(
43        rows: usize,
44        cols: usize,
45        offsets: Vec<u32>,
46        col_indices: Vec<u32>,
47        values: Vec<T>,
48    ) -> Result<Self, SparseError> {
49        validate_csr_invariants(rows, cols, &offsets, &col_indices, values.len())?;
50        Ok(Self { rows, cols, offsets, col_indices, values })
51    }
52
53    /// Convert from COO format to CSR.
54    ///
55    /// Sorts triplets by row, then by column within each row.
56    /// Duplicate entries are summed (standard convention).
57    #[must_use]
58    pub fn from_coo(coo: &CooMatrix<T>) -> Self
59    where
60        T: std::ops::AddAssign + Copy,
61    {
62        let rows = coo.rows;
63        let cols = coo.cols;
64        let nnz = coo.nnz();
65
66        if nnz == 0 {
67            return Self {
68                rows,
69                cols,
70                offsets: vec![0; rows + 1],
71                col_indices: Vec::new(),
72                values: Vec::new(),
73            };
74        }
75
76        // Count nonzeros per row
77        let mut row_counts = vec![0u32; rows];
78        for &r in &coo.row_indices {
79            row_counts[r as usize] += 1;
80        }
81
82        // Build offsets via prefix sum
83        let mut offsets = vec![0u32; rows + 1];
84        for i in 0..rows {
85            offsets[i + 1] = offsets[i] + row_counts[i];
86        }
87
88        // Fill col_indices and values (sort by row)
89        let mut col_indices = vec![0u32; nnz];
90        let mut values = vec![T::default(); nnz];
91        let mut write_pos = offsets.clone();
92
93        for idx in 0..nnz {
94            let r = coo.row_indices[idx] as usize;
95            let pos = write_pos[r] as usize;
96            col_indices[pos] = coo.col_indices[idx];
97            values[pos] = coo.values[idx];
98            write_pos[r] += 1;
99        }
100
101        // Sort columns within each row
102        for i in 0..rows {
103            let start = offsets[i] as usize;
104            let end = offsets[i + 1] as usize;
105            if end - start > 1 {
106                // Simple insertion sort (rows are typically short)
107                for j in (start + 1)..end {
108                    let mut k = j;
109                    while k > start && col_indices[k - 1] > col_indices[k] {
110                        col_indices.swap(k - 1, k);
111                        values.swap(k - 1, k);
112                        k -= 1;
113                    }
114                }
115            }
116        }
117
118        Self { rows, cols, offsets, col_indices, values }
119    }
120
121    /// Create an identity matrix of size n.
122    #[must_use]
123    pub fn identity(n: usize) -> Self
124    where
125        T: From<f32>,
126    {
127        let offsets: Vec<u32> = (0..=n).map(|i| i as u32).collect();
128        let col_indices: Vec<u32> = (0..n).map(|i| i as u32).collect();
129        let values: Vec<T> = (0..n).map(|_| T::from(1.0)).collect();
130        Self { rows: n, cols: n, offsets, col_indices, values }
131    }
132
133    /// Number of rows.
134    #[must_use]
135    pub fn rows(&self) -> usize {
136        self.rows
137    }
138
139    /// Number of columns.
140    #[must_use]
141    pub fn cols(&self) -> usize {
142        self.cols
143    }
144
145    /// Number of stored nonzero entries.
146    #[must_use]
147    pub fn nnz(&self) -> usize {
148        self.values.len()
149    }
150
151    /// Row offsets array (len = rows + 1).
152    #[must_use]
153    pub fn offsets(&self) -> &[u32] {
154        &self.offsets
155    }
156
157    /// Column indices array (len = nnz).
158    #[must_use]
159    pub fn col_indices(&self) -> &[u32] {
160        &self.col_indices
161    }
162
163    /// Values array (len = nnz).
164    #[must_use]
165    pub fn values(&self) -> &[T] {
166        &self.values
167    }
168
169    /// Average number of nonzeros per row.
170    #[must_use]
171    #[allow(clippy::cast_precision_loss)]
172    pub fn avg_nnz_per_row(&self) -> f64 {
173        if self.rows == 0 {
174            0.0
175        } else {
176            self.nnz() as f64 / self.rows as f64
177        }
178    }
179
180    /// Variance of row lengths (key metric for algorithm selection).
181    ///
182    /// High variance → merge-based SpMV; low variance → row-split SpMV.
183    #[must_use]
184    #[allow(clippy::cast_precision_loss)]
185    pub fn row_length_variance(&self) -> f64 {
186        if self.rows == 0 {
187            return 0.0;
188        }
189        let mean = self.avg_nnz_per_row();
190        let sum_sq: f64 = (0..self.rows)
191            .map(|i| {
192                let len = f64::from(self.offsets[i + 1] - self.offsets[i]);
193                (len - mean) * (len - mean)
194            })
195            .sum();
196        sum_sq / self.rows as f64
197    }
198
199    /// Convert to dense matrix (row-major).
200    #[must_use]
201    pub fn to_dense(&self) -> Vec<T>
202    where
203        T: Copy + std::ops::AddAssign,
204    {
205        let mut dense = vec![T::default(); self.rows * self.cols];
206        for i in 0..self.rows {
207            let start = self.offsets[i] as usize;
208            let end = self.offsets[i + 1] as usize;
209            for idx in start..end {
210                let j = self.col_indices[idx] as usize;
211                dense[i * self.cols + j] += self.values[idx];
212            }
213        }
214        dense
215    }
216}