Skip to main content

scirs2_sparse/formats/
csr5.rs

1//! CSR5 sparse matrix format for balanced SpMV
2//!
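//! CSR5 is a 2D tiling of the CSR format designed for balanced parallel SpMV.
//! This implementation uses a simplified 1-D variant: the non-zero elements are
//! partitioned, in CSR storage order, into consecutive tiles of a configurable
//! width (the last tile may be shorter), and a per-tile descriptor enables a
//! segmented reduction within each tile.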
6//!
7//! The key idea: instead of assigning rows to threads (which can be imbalanced),
8//! CSR5 assigns equal-sized tiles of non-zeros to threads and uses the tile
9//! descriptors to handle row boundaries within tiles via segmented reduction.
10//!
11//! # References
12//!
13//! - Liu, W. & Vinter, B. (2015). "CSR5: An Efficient Storage Format for
14//!   Cross-Platform Sparse Matrix-Vector Multiplication." ICS'15.
15
16use crate::csr::CsrMatrix;
17use crate::error::{SparseError, SparseResult};
18use scirs2_core::numeric::{SparseElement, Zero};
19use std::fmt::Debug;
20
/// Tile descriptor for CSR5 segmented reduction.
///
/// A tile is a run of consecutive non-zeros in CSR storage order. Its
/// descriptor records, for every element position in the tile, which row the
/// element belongs to and whether it is the first non-zero of that row, plus
/// summary information about the row transitions inside the tile.
#[derive(Debug, Clone)]
pub struct TileDescriptor {
    /// Row that the first element of this tile belongs to (empty rows are
    /// never reported here, since they own no non-zeros).
    pub first_row: usize,
    /// Whether a new row starts at some position `i > 0` of this tile, i.e.
    /// whether the tile contains at least one row transition.
    pub has_segment_boundary: bool,
    /// Number of positions `i > 0` in this tile that start a new row.
    ///
    /// A row starting exactly at position 0 is not counted, and the last
    /// counted row may continue into the next tile. This is therefore the
    /// number of row transitions inside the tile, not the number of rows
    /// fully contained in it.
    pub num_complete_rows: usize,
    /// For each element position in the tile, the row it belongs to.
    pub row_ids: Vec<usize>,
    /// `is_segment_start[i]` is true if position `i` holds the first non-zero
    /// of its row (this includes position 0 when a row starts there).
    pub is_segment_start: Vec<bool>,
}
40
41/// CSR5 sparse matrix format.
42///
43/// Tiles the CSR non-zeros into 2D tiles for balanced SpMV.
44/// Each tile contains `tile_width` non-zeros.
45#[derive(Debug, Clone)]
46pub struct Csr5Matrix<T> {
47    /// Number of rows.
48    pub nrows: usize,
49    /// Number of columns.
50    pub ncols: usize,
51    /// Tile width (number of non-zeros per tile).
52    pub tile_width: usize,
53    /// Number of tiles.
54    pub num_tiles: usize,
55    /// Column indices (same as CSR).
56    pub col_indices: Vec<usize>,
57    /// Values (same as CSR).
58    pub values: Vec<T>,
59    /// Row pointers (same as CSR).
60    pub row_ptr: Vec<usize>,
61    /// Tile descriptors for segmented scan.
62    pub tile_desc: Vec<TileDescriptor>,
63    /// Tile pointers: `tile_ptr[t]` = starting offset of tile `t` in col_indices/values.
64    pub tile_ptr: Vec<usize>,
65}
66
67impl<T> Csr5Matrix<T>
68where
69    T: Clone + Copy + Zero + SparseElement + Debug,
70{
71    /// Construct a CSR5 matrix from a CSR matrix.
72    ///
73    /// # Arguments
74    ///
75    /// * `csr` - Source CSR matrix.
76    /// * `tile_width` - Number of non-zeros per tile. Typical values: 16, 32, 64.
77    pub fn from_csr(csr: &CsrMatrix<T>, tile_width: usize) -> SparseResult<Self> {
78        if tile_width == 0 {
79            return Err(SparseError::ValueError(
80                "tile_width must be at least 1".to_string(),
81            ));
82        }
83
84        let (nrows, ncols) = csr.shape();
85        let nnz = csr.nnz();
86
87        // Copy CSR data
88        let col_indices = csr.indices.clone();
89        let values = csr.data.clone();
90        let row_ptr = csr.indptr.clone();
91
92        // Compute number of tiles
93        let num_tiles = if nnz == 0 {
94            0
95        } else {
96            nnz.div_ceil(tile_width)
97        };
98
99        // Build tile pointers
100        let mut tile_ptr = Vec::with_capacity(num_tiles + 1);
101        for t in 0..=num_tiles {
102            tile_ptr.push((t * tile_width).min(nnz));
103        }
104
105        // Build tile descriptors using calibration phase
106        let tile_desc = Self::calibrate(&row_ptr, nrows, nnz, tile_width, num_tiles);
107
108        Ok(Self {
109            nrows,
110            ncols,
111            tile_width,
112            num_tiles,
113            col_indices,
114            values,
115            row_ptr,
116            tile_desc,
117            tile_ptr,
118        })
119    }
120
121    /// Calibration phase: build tile descriptors.
122    ///
123    /// For each tile, determine which rows its elements belong to and
124    /// where segment boundaries (row transitions) occur.
125    fn calibrate(
126        row_ptr: &[usize],
127        nrows: usize,
128        nnz: usize,
129        tile_width: usize,
130        num_tiles: usize,
131    ) -> Vec<TileDescriptor> {
132        let mut descriptors = Vec::with_capacity(num_tiles);
133
134        for t in 0..num_tiles {
135            let tile_start = t * tile_width;
136            let tile_end = nnz.min(tile_start + tile_width);
137            let tile_len = tile_end - tile_start;
138
139            // Find which row the first element belongs to
140            let first_row = Self::find_row(row_ptr, nrows, tile_start);
141
142            // Build row IDs and segment start flags
143            let mut row_ids = Vec::with_capacity(tile_len);
144            let mut is_segment_start = Vec::with_capacity(tile_len);
145            let mut current_row = first_row;
146            let mut num_complete_rows = 0usize;
147            let mut has_boundary = false;
148
149            for pos in tile_start..tile_end {
150                // Advance current_row until row_ptr[current_row + 1] > pos
151                while current_row < nrows && row_ptr[current_row + 1] <= pos {
152                    current_row += 1;
153                }
154
155                let is_start = if pos == tile_start {
156                    // First element in tile: it's a segment start if it's also
157                    // the first element of its row
158                    pos == row_ptr[current_row]
159                } else {
160                    pos == row_ptr[current_row]
161                };
162
163                if is_start && pos != tile_start {
164                    has_boundary = true;
165                    num_complete_rows += 1;
166                }
167
168                row_ids.push(current_row);
169                is_segment_start.push(is_start);
170            }
171
172            descriptors.push(TileDescriptor {
173                first_row,
174                has_segment_boundary: has_boundary,
175                num_complete_rows,
176                row_ids,
177                is_segment_start,
178            });
179        }
180
181        descriptors
182    }
183
184    /// Binary search to find which row a given NNZ position belongs to.
185    fn find_row(row_ptr: &[usize], nrows: usize, pos: usize) -> usize {
186        // row_ptr[row] <= pos < row_ptr[row+1]
187        let mut lo = 0usize;
188        let mut hi = nrows;
189        while lo < hi {
190            let mid = lo + (hi - lo) / 2;
191            if row_ptr[mid + 1] <= pos {
192                lo = mid + 1;
193            } else {
194                hi = mid;
195            }
196        }
197        lo
198    }
199
200    /// Perform SpMV: `y = self * x`.
201    ///
202    /// Two-phase SpMV:
203    /// 1. Each tile computes partial sums using segmented reduction.
204    /// 2. Partial sums for rows spanning multiple tiles are merged.
205    pub fn spmv(&self, x: &[T]) -> SparseResult<Vec<T>>
206    where
207        T: std::ops::Add<Output = T> + std::ops::Mul<Output = T>,
208    {
209        if x.len() != self.ncols {
210            return Err(SparseError::DimensionMismatch {
211                expected: self.ncols,
212                found: x.len(),
213            });
214        }
215
216        let mut y = vec![T::sparse_zero(); self.nrows];
217
218        if self.num_tiles == 0 {
219            return Ok(y);
220        }
221
222        // Phase 1: per-tile segmented reduction
223        // For each tile, accumulate partial sums and write completed rows to y.
224        // Carry-over partial sums are collected for cross-tile merging.
225
226        // carry[t] = (row, partial_sum) for the last segment in tile t
227        // that may continue into the next tile.
228        let mut carries: Vec<Option<(usize, T)>> = vec![None; self.num_tiles];
229
230        for t in 0..self.num_tiles {
231            let desc = &self.tile_desc[t];
232            let tile_start = self.tile_ptr[t];
233            let tile_end = self.tile_ptr[t + 1];
234            let tile_len = tile_end - tile_start;
235
236            if tile_len == 0 {
237                continue;
238            }
239
240            // Segmented scan within the tile
241            let mut acc = T::sparse_zero();
242            let mut current_row = desc.first_row;
243
244            for i in 0..tile_len {
245                let pos = tile_start + i;
246                let row = desc.row_ids[i];
247
248                if row != current_row {
249                    // Row boundary: flush acc to y or carry
250                    if i == 0 {
251                        // The very first position already switched — means
252                        // previous tile's carry goes to current_row
253                    } else {
254                        // current_row's segment is complete within this tile
255                        y[current_row] = y[current_row] + acc;
256                    }
257                    acc = T::sparse_zero();
258                    current_row = row;
259                }
260
261                acc = acc + self.values[pos] * x[self.col_indices[pos]];
262            }
263
264            // Remaining acc is carry for this tile's last row
265            carries[t] = Some((current_row, acc));
266        }
267
268        // Phase 2: merge carries
269        // Process carries from the first tile forward.
270        // If tile t's first row equals tile (t-1)'s carry row, accumulate.
271        // Otherwise, flush the previous carry.
272
273        for t in 0..self.num_tiles {
274            if let Some((row, val)) = carries[t] {
275                // Check if the next tile continues this row
276                let continues = if t + 1 < self.num_tiles {
277                    let next_desc = &self.tile_desc[t + 1];
278                    next_desc.first_row == row
279                } else {
280                    false
281                };
282
283                if continues {
284                    // Propagate carry to the next tile
285                    if let Some((_, ref mut next_val)) = carries[t + 1] {
286                        // The next tile's first partial sum needs this carry added
287                        // But the next tile's first_row = row, so the next carry
288                        // already includes partial sums. We add to y and let next
289                        // tile's carry handle the rest.
290                        y[row] = y[row] + val;
291                    } else {
292                        y[row] = y[row] + val;
293                    }
294                } else {
295                    y[row] = y[row] + val;
296                }
297            }
298        }
299
300        Ok(y)
301    }
302
303    /// Convert back to CSR format.
304    pub fn to_csr(&self) -> SparseResult<CsrMatrix<T>>
305    where
306        T: std::cmp::PartialEq,
307    {
308        // The CSR data is already stored internally; just reconstruct triplets.
309        let mut row_indices: Vec<usize> = Vec::with_capacity(self.values.len());
310        let mut col_indices: Vec<usize> = Vec::with_capacity(self.values.len());
311        let mut data: Vec<T> = Vec::with_capacity(self.values.len());
312
313        for row in 0..self.nrows {
314            let start = self.row_ptr[row];
315            let end = self.row_ptr[row + 1];
316            for pos in start..end {
317                row_indices.push(row);
318                col_indices.push(self.col_indices[pos]);
319                data.push(self.values[pos]);
320            }
321        }
322
323        CsrMatrix::new(data, row_indices, col_indices, (self.nrows, self.ncols))
324    }
325
326    /// Number of non-zeros.
327    pub fn nnz(&self) -> usize {
328        self.values.len()
329    }
330
331    /// Return the tile width.
332    pub fn get_tile_width(&self) -> usize {
333        self.tile_width
334    }
335
336    /// Return the number of tiles.
337    pub fn get_num_tiles(&self) -> usize {
338        self.num_tiles
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    fn make_tridiag_csr(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(2.0);
            if i > 0 {
                rows.push(i);
                cols.push(i - 1);
                vals.push(-1.0);
            }
            if i + 1 < n {
                rows.push(i);
                cols.push(i + 1);
                vals.push(-1.0);
            }
        }
        CsrMatrix::new(vals, rows, cols, (n, n)).expect("csr")
    }

    /// 7x3 matrix with leading, interior and trailing empty rows.
    /// Rows 1, 3 and 5 hold 2, 1 and 3 non-zeros; rows 0, 2, 4 and 6 are empty.
    fn make_csr_with_empty_rows() -> CsrMatrix<f64> {
        let rows = vec![1, 1, 3, 5, 5, 5];
        let cols = vec![0, 2, 1, 0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        CsrMatrix::new(vals, rows, cols, (7, 3)).expect("csr")
    }

    fn csr_spmv(csr: &CsrMatrix<f64>, x: &[f64]) -> Vec<f64> {
        let (nrows, _) = csr.shape();
        let mut y = vec![0.0f64; nrows];
        for row in 0..nrows {
            for j in csr.indptr[row]..csr.indptr[row + 1] {
                y[row] += csr.data[j] * x[csr.indices[j]];
            }
        }
        y
    }

    /// Element-wise closeness check between two vectors of equal length.
    fn assert_all_close(got: &[f64], want: &[f64]) {
        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(want) {
            assert_relative_eq!(*g, *w, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_csr5_spmv_matches_csr() {
        let csr = make_tridiag_csr(8);
        let x: Vec<f64> = (0..8).map(|i| (i + 1) as f64).collect();
        let y_ref = csr_spmv(&csr, &x);

        for &tw in &[4usize, 8, 16, 32] {
            let csr5 = Csr5Matrix::from_csr(&csr, tw).expect("csr5");
            let y_csr5 = csr5.spmv(&x).expect("spmv");
            for i in 0..8 {
                assert_relative_eq!(y_csr5[i], y_ref[i], epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_csr5_preserves_nnz() {
        let csr = make_tridiag_csr(10);
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        assert_eq!(csr5.nnz(), csr.nnz());
    }

    #[test]
    fn test_csr5_roundtrip() {
        let csr = make_tridiag_csr(6);
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        let csr2 = csr5.to_csr().expect("to_csr");
        assert_eq!(csr2.nnz(), csr.nnz());
        let x: Vec<f64> = (0..6).map(|i| (i + 1) as f64).collect();
        let y1 = csr_spmv(&csr, &x);
        let y2 = csr_spmv(&csr2, &x);
        for i in 0..6 {
            assert_relative_eq!(y1[i], y2[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_csr5_irregular_matrix() {
        // Matrix with varying row lengths
        let rows = vec![0, 0, 0, 0, 1, 2, 2, 3, 3, 3];
        let cols = vec![0, 1, 2, 3, 0, 0, 3, 1, 2, 3];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let csr = CsrMatrix::new(vals, rows, cols, (4, 4)).expect("csr");

        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y_ref = csr_spmv(&csr, &x);

        let csr5 = Csr5Matrix::from_csr(&csr, 3).expect("csr5");
        let y_csr5 = csr5.spmv(&x).expect("spmv");

        for i in 0..4 {
            assert_relative_eq!(y_csr5[i], y_ref[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_csr5_empty_matrix() {
        let csr = CsrMatrix::<f64>::new(vec![], vec![], vec![], (3, 3)).expect("csr");
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        assert_eq!(csr5.nnz(), 0);
        assert_eq!(csr5.num_tiles, 0);
        let y = csr5.spmv(&[0.0, 0.0, 0.0]).expect("spmv");
        assert_eq!(y, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_csr5_tile_width_error() {
        let csr = make_tridiag_csr(4);
        assert!(Csr5Matrix::<f64>::from_csr(&csr, 0).is_err());
    }

    #[test]
    fn test_csr5_single_row() {
        let csr =
            CsrMatrix::new(vec![1.0, 2.0, 3.0], vec![0, 0, 0], vec![0, 1, 2], (1, 3)).expect("csr");
        let x = vec![1.0, 2.0, 3.0];
        let y_ref = csr_spmv(&csr, &x);
        let csr5 = Csr5Matrix::from_csr(&csr, 2).expect("csr5");
        let y = csr5.spmv(&x).expect("spmv");
        assert_relative_eq!(y[0], y_ref[0], epsilon = 1e-12);
    }

    #[test]
    fn test_csr5_large_tile() {
        // Tile larger than nnz
        let csr = make_tridiag_csr(4);
        let x: Vec<f64> = (0..4).map(|i| (i + 1) as f64).collect();
        let y_ref = csr_spmv(&csr, &x);
        let csr5 = Csr5Matrix::from_csr(&csr, 100).expect("csr5");
        let y = csr5.spmv(&x).expect("spmv");
        for i in 0..4 {
            assert_relative_eq!(y[i], y_ref[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_csr5_tile_layout() {
        // A 10x10 tridiagonal matrix has 28 non-zeros -> ceil(28 / 4) = 7 tiles.
        let csr = make_tridiag_csr(10);
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        assert_eq!(csr5.get_tile_width(), 4);
        assert_eq!(csr5.get_num_tiles(), 7);
        assert_eq!(csr5.tile_desc.len(), 7);
        assert_eq!(csr5.tile_ptr.len(), 8);
        assert_eq!(csr5.tile_ptr.first(), Some(&0));
        assert_eq!(csr5.tile_ptr.last(), Some(&28));
    }

    #[test]
    fn test_csr5_empty_rows_spmv() {
        let csr = make_csr_with_empty_rows();
        let x = vec![1.0, -2.0, 0.5];
        let y_ref = csr_spmv(&csr, &x);
        // Sweep tile widths so rows start at every tile offset, including 0.
        for tw in 1..=7 {
            let csr5 = Csr5Matrix::from_csr(&csr, tw).expect("csr5");
            let y = csr5.spmv(&x).expect("spmv");
            assert_all_close(&y, &y_ref);
        }
    }

    #[test]
    fn test_csr5_tile_descriptors_skip_empty_rows() {
        let csr = make_csr_with_empty_rows();
        let csr5 = Csr5Matrix::from_csr(&csr, 2).expect("csr5");
        assert_eq!(csr5.num_tiles, 3);
        assert_eq!(csr5.tile_ptr, vec![0, 2, 4, 6]);

        // Tile 0 lies entirely inside row 1 (leading empty row 0 is skipped).
        let d0 = &csr5.tile_desc[0];
        assert_eq!(d0.first_row, 1);
        assert_eq!(d0.row_ids, vec![1, 1]);
        assert_eq!(d0.is_segment_start, vec![true, false]);
        assert!(!d0.has_segment_boundary);
        assert_eq!(d0.num_complete_rows, 0);

        // Tile 1 holds row 3's only entry, then jumps over empty row 4 to row 5.
        let d1 = &csr5.tile_desc[1];
        assert_eq!(d1.first_row, 3);
        assert_eq!(d1.row_ids, vec![3, 5]);
        assert_eq!(d1.is_segment_start, vec![true, true]);
        assert!(d1.has_segment_boundary);
        assert_eq!(d1.num_complete_rows, 1);
    }

    #[test]
    fn test_csr5_row_spanning_many_tiles() {
        let n = 7usize;
        let vals: Vec<f64> = (1..=n).map(|v| v as f64).collect();
        let rows: Vec<usize> = vec![0; n];
        let cols: Vec<usize> = (0..n).collect();
        let csr = CsrMatrix::new(vals, rows, cols, (1, n)).expect("csr");
        let csr5 = Csr5Matrix::from_csr(&csr, 2).expect("csr5");
        assert_eq!(csr5.get_num_tiles(), 4);
        let y = csr5.spmv(&vec![1.0; n]).expect("spmv");
        assert_relative_eq!(y[0], 28.0, epsilon = 1e-12);
    }

    #[test]
    fn test_csr5_spmv_dimension_mismatch() {
        let csr = make_tridiag_csr(4);
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        assert!(csr5.spmv(&[1.0, 2.0, 3.0]).is_err());
    }

    #[test]
    fn test_csr5_roundtrip_with_empty_rows() {
        let csr = make_csr_with_empty_rows();
        let csr5 = Csr5Matrix::from_csr(&csr, 4).expect("csr5");
        let csr2 = csr5.to_csr().expect("to_csr");
        assert_eq!(csr2.shape(), (7, 3));
        assert_eq!(csr2.nnz(), csr.nnz());
        let x = vec![1.0, 2.0, 3.0];
        assert_all_close(&csr_spmv(&csr2, &x), &csr_spmv(&csr, &x));
    }
}