Skip to main content

trueno_sparse/
sell.rs

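//! Sliced ELLPACK (SELL) sparse matrix format.
//!
//! # Contract: sparse-formats-v1.yaml
//!
//! SELL-C format: consecutive rows are grouped into slices of C rows and kept in
//! their original order (no σ-window sorting by row length, i.e. SELL-C-σ with σ = 1).
//! Each slice is padded to the max row length in that slice and stored column-major.
//! This gives SIMD-friendly contiguous access patterns.
//!
//! ## References
//! - Kreutzer et al., "A unified sparse matrix data format for efficient general sparse
//!   matrix-vector multiplication on modern processors with wide SIMD units",
//!   SIAM J. Sci. Comput., 2014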
11
12use crate::csr::CsrMatrix;
13use crate::error::SparseError;
14
/// Sliced ELLPACK sparse matrix.
///
/// Rows are grouped into consecutive slices of `slice_size` rows, in their
/// original order. Within each slice, column indices and values are stored in
/// column-major order: entry `j` of every row in the slice, then entry `j + 1`,
/// and so on. Each slice is padded to the max row length in that slice.
///
/// Padding entries hold column index 0 and value 0.0. Padding covers both
/// positions past the end of a short row and, in the last slice, rows past the
/// end of the matrix.
///
/// Equality compares the stored representation field by field. Because the
/// values are `f32`, a matrix that stores a NaN never compares equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct SellMatrix {
    /// Number of matrix rows.
    rows: usize,
    /// Number of matrix columns.
    cols: usize,
    /// Rows per slice (the C parameter).
    slice_size: usize,
    /// Number of slices = ceil(rows / slice_size).
    num_slices: usize,
    /// Offset into col_indices/values for each slice (len = num_slices + 1).
    /// The final entry is the total number of stored elements.
    slice_offsets: Vec<u32>,
    /// Max row length in each slice (len = num_slices). Slice `s` occupies
    /// `slice_size * slice_widths[s]` stored elements.
    slice_widths: Vec<u32>,
    /// Column indices (padded, column-major within each slice).
    col_indices: Vec<u32>,
    /// Values (padded, column-major within each slice).
    values: Vec<f32>,
}
36
37impl SellMatrix {
38    /// Convert a CSR matrix to SELL format with the given slice size.
39    ///
40    /// Typical slice_size: 32 or 64 (matching SIMD width or warp size).
41    #[must_use]
42    pub fn from_csr(csr: &CsrMatrix<f32>, slice_size: usize) -> Self {
43        let rows = csr.rows();
44        let cols = csr.cols();
45        let c = if slice_size == 0 { 1 } else { slice_size };
46        let num_slices = rows.div_ceil(c);
47
48        let mut slice_offsets = Vec::with_capacity(num_slices + 1);
49        let mut slice_widths = Vec::with_capacity(num_slices);
50        let mut col_indices = Vec::new();
51        let mut values = Vec::new();
52
53        slice_offsets.push(0u32);
54
55        for s in 0..num_slices {
56            let row_start = s * c;
57            let row_end = (row_start + c).min(rows);
58            let actual_rows = row_end - row_start;
59
60            // Find max row length in this slice
61            let max_len = compute_slice_width(csr, row_start, row_end);
62            slice_widths.push(max_len as u32);
63
64            // Store in column-major order within the slice
65            fill_slice_data(csr, row_start, actual_rows, c, max_len, &mut col_indices, &mut values);
66
67            let slice_elements = c * max_len;
68            let offset = slice_offsets.last().copied().unwrap_or(0);
69            slice_offsets.push(offset + slice_elements as u32);
70        }
71
72        Self {
73            rows,
74            cols,
75            slice_size: c,
76            num_slices,
77            slice_offsets,
78            slice_widths,
79            col_indices,
80            values,
81        }
82    }
83
84    /// SpMV: y = α·A·x + β·y
85    ///
86    /// # Errors
87    ///
88    /// Returns error on dimension mismatch.
89    pub fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
90        if x.len() != self.cols {
91            return Err(SparseError::SpMVDimensionMismatch {
92                matrix_cols: self.cols,
93                x_len: x.len(),
94            });
95        }
96        if y.len() != self.rows {
97            return Err(SparseError::SpMVOutputDimensionMismatch {
98                matrix_rows: self.rows,
99                y_len: y.len(),
100            });
101        }
102
103        // Scale y by beta
104        for val in y.iter_mut() {
105            *val *= beta;
106        }
107
108        let c = self.slice_size;
109
110        for s in 0..self.num_slices {
111            let base = self.slice_offsets[s] as usize;
112            let width = self.slice_widths[s] as usize;
113            let row_start = s * c;
114            let row_end = (row_start + c).min(self.rows);
115
116            spmv_slice(
117                &self.col_indices,
118                &self.values,
119                x,
120                y,
121                alpha,
122                base,
123                c,
124                width,
125                row_start,
126                row_end,
127            );
128        }
129
130        Ok(())
131    }
132
133    /// Number of rows.
134    #[must_use]
135    pub fn rows(&self) -> usize {
136        self.rows
137    }
138
139    /// Number of columns.
140    #[must_use]
141    pub fn cols(&self) -> usize {
142        self.cols
143    }
144
145    /// Slice size (C parameter).
146    #[must_use]
147    pub fn slice_size(&self) -> usize {
148        self.slice_size
149    }
150
151    /// Total stored elements (including padding zeros).
152    #[must_use]
153    pub fn storage_size(&self) -> usize {
154        self.values.len()
155    }
156}
157
158/// Compute max row length in a slice.
159fn compute_slice_width(csr: &CsrMatrix<f32>, row_start: usize, row_end: usize) -> usize {
160    let offsets = csr.offsets();
161    let mut max_len = 0usize;
162    for r in row_start..row_end {
163        let len = (offsets[r + 1] - offsets[r]) as usize;
164        if len > max_len {
165            max_len = len;
166        }
167    }
168    max_len
169}
170
171/// Fill column-major data for one slice.
172fn fill_slice_data(
173    csr: &CsrMatrix<f32>,
174    row_start: usize,
175    actual_rows: usize,
176    c: usize,
177    max_len: usize,
178    col_indices: &mut Vec<u32>,
179    values: &mut Vec<f32>,
180) {
181    let csr_off = csr.offsets();
182    let csr_cols = csr.col_indices();
183    let csr_vals = csr.values();
184
185    // Column-major: for each column position j, store all rows
186    for j in 0..max_len {
187        for local_r in 0..c {
188            let global_r = row_start + local_r;
189            if local_r < actual_rows {
190                let row_start_idx = csr_off[global_r] as usize;
191                let row_len = (csr_off[global_r + 1] - csr_off[global_r]) as usize;
192                if j < row_len {
193                    col_indices.push(csr_cols[row_start_idx + j]);
194                    values.push(csr_vals[row_start_idx + j]);
195                } else {
196                    col_indices.push(0);
197                    values.push(0.0);
198                }
199            } else {
200                // Padding rows (beyond actual matrix rows)
201                col_indices.push(0);
202                values.push(0.0);
203            }
204        }
205    }
206}
207
/// Accumulate `alpha * A_slice * x` into the rows of `y` covered by one SELL slice.
///
/// The slice's data starts at `base` in `col_indices` / `values`. It is laid out
/// column-major with a stride of `c` per lane, for `width` lanes. Only the
/// first `row_end - row_start` slots of each lane are real rows (the rest are
/// padding past the matrix end), and only those rows of `y` are updated.
#[allow(clippy::too_many_arguments)] // flat kernel signature: raw slices plus slice geometry
fn spmv_slice(
    col_indices: &[u32],
    values: &[f32],
    x: &[f32],
    y: &mut [f32],
    alpha: f32,
    base: usize,
    c: usize,
    width: usize,
    row_start: usize,
    row_end: usize,
) {
    for lane in 0..width {
        let lane_start = base + lane * c;
        let lane_end = lane_start + (row_end - row_start);

        let targets = y[row_start..row_end].iter_mut();
        let entries = col_indices[lane_start..lane_end]
            .iter()
            .zip(&values[lane_start..lane_end]);

        for (out, (&col, &val)) in targets.zip(entries) {
            *out += alpha * val * x[col as usize];
        }
    }
}