Skip to main content

trueno_sparse/
bsr.rs

//! Block Sparse Row (BSR) format.
//!
//! Stores sparse matrices as blocks of dense sub-matrices, aligned on a
//! regular block grid. Efficient for FEM and structured sparsity patterns.
5
6use crate::csr::CsrMatrix;
7use crate::error::SparseError;
8use crate::ops::SparseOps;
9
/// Block Sparse Row (BSR) matrix of `f32` values.
///
/// Represents a matrix of shape
/// `(block_rows * block_size) × (block_cols * block_size)` in which only the
/// non-zero `block_size × block_size` blocks are stored, using a
/// CSR-of-blocks layout: the blocks of block row `i` are the indices `k` with
/// `offsets[i] <= k < offsets[i + 1]`; block `k` lies in block column
/// `col_indices[k]` and its entries occupy
/// `values[k * block_size * block_size..(k + 1) * block_size * block_size]`.
#[derive(Debug, Clone)]
pub struct BsrMatrix {
    /// Number of block rows.
    block_rows: usize,
    /// Number of block columns.
    block_cols: usize,
    /// Block dimension (blocks are `block_size × block_size`).
    block_size: usize,
    /// Row offsets into `col_indices` for block CSR (length = `block_rows + 1`).
    offsets: Vec<u32>,
    /// Block column index of each stored block.
    col_indices: Vec<u32>,
    /// Dense block values, stored row-major per block.
    /// Length = `nnz_blocks * block_size * block_size`.
    values: Vec<f32>,
}
30
31impl BsrMatrix {
32    /// Create a new BSR matrix.
33    ///
34    /// # Arguments
35    ///
36    /// - `block_rows`, `block_cols`: number of block rows/columns
37    /// - `block_size`: dimension of each square block
38    /// - `offsets`: CSR-style row offsets for blocks
39    /// - `col_indices`: block column indices
40    /// - `values`: dense block data (row-major per block)
41    ///
42    /// # Errors
43    ///
44    /// Returns error if structure is invalid.
45    pub fn new(
46        block_rows: usize,
47        block_cols: usize,
48        block_size: usize,
49        offsets: Vec<u32>,
50        col_indices: Vec<u32>,
51        values: Vec<f32>,
52    ) -> Result<Self, SparseError> {
53        if offsets.len() != block_rows + 1 {
54            return Err(SparseError::InvalidOffsetsLength {
55                actual: offsets.len(),
56                expected: block_rows + 1,
57            });
58        }
59        let nnz_blocks = col_indices.len();
60        let expected_vals = nnz_blocks * block_size * block_size;
61        if values.len() != expected_vals {
62            return Err(SparseError::LengthMismatch {
63                col_len: expected_vals,
64                val_len: values.len(),
65            });
66        }
67        Ok(Self { block_rows, block_cols, block_size, offsets, col_indices, values })
68    }
69
70    /// Create BSR from a dense matrix.
71    ///
72    /// Pads the matrix if dimensions aren't divisible by block_size.
73    /// Only stores blocks with at least one non-zero element.
74    pub fn from_dense(data: &[f32], rows: usize, cols: usize, block_size: usize) -> Self {
75        let br = rows.div_ceil(block_size);
76        let bc = cols.div_ceil(block_size);
77
78        let mut offsets = vec![0u32; br + 1];
79        let mut col_indices = Vec::new();
80        let mut values = Vec::new();
81        let bs2 = block_size * block_size;
82
83        for bi in 0..br {
84            for bj in 0..bc {
85                let mut block = vec![0.0f32; bs2];
86                let mut has_nonzero = false;
87                for li in 0..block_size {
88                    for lj in 0..block_size {
89                        let gi = bi * block_size + li;
90                        let gj = bj * block_size + lj;
91                        if gi < rows && gj < cols {
92                            let val = data[gi * cols + gj];
93                            block[li * block_size + lj] = val;
94                            if val != 0.0 {
95                                has_nonzero = true;
96                            }
97                        }
98                    }
99                }
100                if has_nonzero {
101                    col_indices.push(bj as u32);
102                    values.extend_from_slice(&block);
103                }
104            }
105            offsets[bi + 1] = col_indices.len() as u32;
106        }
107
108        Self { block_rows: br, block_cols: bc, block_size, offsets, col_indices, values }
109    }
110
111    /// Convert to CSR format.
112    ///
113    /// # Errors
114    ///
115    /// Returns error if the internal conversion produces invalid CSR.
116    pub fn to_csr(&self) -> Result<CsrMatrix<f32>, SparseError> {
117        let rows = self.block_rows * self.block_size;
118        let cols = self.block_cols * self.block_size;
119        let bs = self.block_size;
120        let bs2 = bs * bs;
121
122        let mut csr_offsets = vec![0u32; rows + 1];
123        let mut csr_cols = Vec::new();
124        let mut csr_vals = Vec::new();
125
126        for bi in 0..self.block_rows {
127            let blk_start = self.offsets[bi] as usize;
128            let blk_end = self.offsets[bi + 1] as usize;
129
130            for li in 0..bs {
131                let global_row = bi * bs + li;
132                if global_row >= rows {
133                    break;
134                }
135                for blk_idx in blk_start..blk_end {
136                    let bj = self.col_indices[blk_idx] as usize;
137                    for lj in 0..bs {
138                        let global_col = bj * bs + lj;
139                        if global_col >= cols {
140                            continue;
141                        }
142                        let val = self.values[blk_idx * bs2 + li * bs + lj];
143                        if val != 0.0 {
144                            csr_cols.push(global_col as u32);
145                            csr_vals.push(val);
146                        }
147                    }
148                }
149                csr_offsets[global_row + 1] = csr_cols.len() as u32;
150            }
151        }
152
153        CsrMatrix::new(rows, cols, csr_offsets, csr_cols, csr_vals)
154    }
155
156    /// Total matrix rows.
157    pub fn rows(&self) -> usize {
158        self.block_rows * self.block_size
159    }
160
161    /// Total matrix columns.
162    pub fn cols(&self) -> usize {
163        self.block_cols * self.block_size
164    }
165
166    /// Number of non-zero blocks.
167    pub fn nnz_blocks(&self) -> usize {
168        self.col_indices.len()
169    }
170
171    /// Block size.
172    pub fn block_size(&self) -> usize {
173        self.block_size
174    }
175}
176
177impl SparseOps for BsrMatrix {
178    fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
179        if x.len() != self.cols() {
180            return Err(SparseError::SpMVDimensionMismatch {
181                matrix_cols: self.cols(),
182                x_len: x.len(),
183            });
184        }
185        if y.len() != self.rows() {
186            return Err(SparseError::SpMVOutputDimensionMismatch {
187                matrix_rows: self.rows(),
188                y_len: y.len(),
189            });
190        }
191
192        let bs = self.block_size;
193        let bs2 = bs * bs;
194
195        // y = beta * y
196        for yi in y.iter_mut() {
197            *yi *= beta;
198        }
199
200        // y += alpha * A * x
201        for bi in 0..self.block_rows {
202            let blk_start = self.offsets[bi] as usize;
203            let blk_end = self.offsets[bi + 1] as usize;
204
205            for blk_idx in blk_start..blk_end {
206                let bj = self.col_indices[blk_idx] as usize;
207                let block = &self.values[blk_idx * bs2..(blk_idx + 1) * bs2];
208
209                for li in 0..bs {
210                    let gi = bi * bs + li;
211                    if gi >= y.len() {
212                        break;
213                    }
214                    let mut sum = 0.0f32;
215                    for lj in 0..bs {
216                        let gj = bj * bs + lj;
217                        if gj < x.len() {
218                            sum += block[li * bs + lj] * x[gj];
219                        }
220                    }
221                    y[gi] += alpha * sum;
222                }
223            }
224        }
225
226        Ok(())
227    }
228
229    fn spmm(
230        &self,
231        alpha: f32,
232        b: &[f32],
233        b_cols: usize,
234        beta: f32,
235        c: &mut [f32],
236    ) -> Result<(), SparseError> {
237        if b.len() != self.cols() * b_cols {
238            return Err(SparseError::SpMVDimensionMismatch {
239                matrix_cols: self.cols(),
240                x_len: b.len(),
241            });
242        }
243        if c.len() != self.rows() * b_cols {
244            return Err(SparseError::SpMVOutputDimensionMismatch {
245                matrix_rows: self.rows(),
246                y_len: c.len(),
247            });
248        }
249
250        let bs = self.block_size;
251        let bs2 = bs * bs;
252
253        // Scale C by beta
254        for ci in c.iter_mut() {
255            *ci *= beta;
256        }
257
258        // C += alpha * A * B
259        for bi in 0..self.block_rows {
260            let blk_start = self.offsets[bi] as usize;
261            let blk_end = self.offsets[bi + 1] as usize;
262
263            for blk_idx in blk_start..blk_end {
264                let bj = self.col_indices[blk_idx] as usize;
265                let block = &self.values[blk_idx * bs2..(blk_idx + 1) * bs2];
266
267                for li in 0..bs {
268                    let gi = bi * bs + li;
269                    if gi >= self.rows() {
270                        break;
271                    }
272                    for lj in 0..bs {
273                        let gj = bj * bs + lj;
274                        if gj >= self.cols() {
275                            continue;
276                        }
277                        let a_val = alpha * block[li * bs + lj];
278                        for k in 0..b_cols {
279                            c[gi * b_cols + k] += a_val * b[gj * b_cols + k];
280                        }
281                    }
282                }
283            }
284        }
285
286        Ok(())
287    }
288}