1use crate::csr::CsrMatrix;
7use crate::error::SparseError;
8use crate::ops::SparseOps;
9
/// Sparse matrix stored as square dense blocks grouped by block row (BSR layout).
///
/// The stored blocks of block row `i` are the entries
/// `offsets[i]..offsets[i + 1]` of `col_indices`; the `k`-th stored block
/// occupies `values[k * block_size * block_size..][..block_size * block_size]`
/// in row-major order.
#[derive(Clone, Debug)]
pub struct BsrMatrix {
    /// Number of block rows.
    block_rows: usize,
    /// Number of block columns.
    block_cols: usize,
    /// Edge length of every (square) block.
    block_size: usize,
    /// Per-block-row start positions into `col_indices` (`block_rows + 1` entries).
    offsets: Vec<u32>,
    /// Block-column index of every stored block.
    col_indices: Vec<u32>,
    /// Dense contents of the stored blocks, concatenated block after block.
    values: Vec<f32>,
}
30
31impl BsrMatrix {
32 pub fn new(
46 block_rows: usize,
47 block_cols: usize,
48 block_size: usize,
49 offsets: Vec<u32>,
50 col_indices: Vec<u32>,
51 values: Vec<f32>,
52 ) -> Result<Self, SparseError> {
53 if offsets.len() != block_rows + 1 {
54 return Err(SparseError::InvalidOffsetsLength {
55 actual: offsets.len(),
56 expected: block_rows + 1,
57 });
58 }
59 let nnz_blocks = col_indices.len();
60 let expected_vals = nnz_blocks * block_size * block_size;
61 if values.len() != expected_vals {
62 return Err(SparseError::LengthMismatch {
63 col_len: expected_vals,
64 val_len: values.len(),
65 });
66 }
67 Ok(Self { block_rows, block_cols, block_size, offsets, col_indices, values })
68 }
69
70 pub fn from_dense(data: &[f32], rows: usize, cols: usize, block_size: usize) -> Self {
75 let br = rows.div_ceil(block_size);
76 let bc = cols.div_ceil(block_size);
77
78 let mut offsets = vec![0u32; br + 1];
79 let mut col_indices = Vec::new();
80 let mut values = Vec::new();
81 let bs2 = block_size * block_size;
82
83 for bi in 0..br {
84 for bj in 0..bc {
85 let mut block = vec![0.0f32; bs2];
86 let mut has_nonzero = false;
87 for li in 0..block_size {
88 for lj in 0..block_size {
89 let gi = bi * block_size + li;
90 let gj = bj * block_size + lj;
91 if gi < rows && gj < cols {
92 let val = data[gi * cols + gj];
93 block[li * block_size + lj] = val;
94 if val != 0.0 {
95 has_nonzero = true;
96 }
97 }
98 }
99 }
100 if has_nonzero {
101 col_indices.push(bj as u32);
102 values.extend_from_slice(&block);
103 }
104 }
105 offsets[bi + 1] = col_indices.len() as u32;
106 }
107
108 Self { block_rows: br, block_cols: bc, block_size, offsets, col_indices, values }
109 }
110
111 pub fn to_csr(&self) -> Result<CsrMatrix<f32>, SparseError> {
117 let rows = self.block_rows * self.block_size;
118 let cols = self.block_cols * self.block_size;
119 let bs = self.block_size;
120 let bs2 = bs * bs;
121
122 let mut csr_offsets = vec![0u32; rows + 1];
123 let mut csr_cols = Vec::new();
124 let mut csr_vals = Vec::new();
125
126 for bi in 0..self.block_rows {
127 let blk_start = self.offsets[bi] as usize;
128 let blk_end = self.offsets[bi + 1] as usize;
129
130 for li in 0..bs {
131 let global_row = bi * bs + li;
132 if global_row >= rows {
133 break;
134 }
135 for blk_idx in blk_start..blk_end {
136 let bj = self.col_indices[blk_idx] as usize;
137 for lj in 0..bs {
138 let global_col = bj * bs + lj;
139 if global_col >= cols {
140 continue;
141 }
142 let val = self.values[blk_idx * bs2 + li * bs + lj];
143 if val != 0.0 {
144 csr_cols.push(global_col as u32);
145 csr_vals.push(val);
146 }
147 }
148 }
149 csr_offsets[global_row + 1] = csr_cols.len() as u32;
150 }
151 }
152
153 CsrMatrix::new(rows, cols, csr_offsets, csr_cols, csr_vals)
154 }
155
156 pub fn rows(&self) -> usize {
158 self.block_rows * self.block_size
159 }
160
161 pub fn cols(&self) -> usize {
163 self.block_cols * self.block_size
164 }
165
166 pub fn nnz_blocks(&self) -> usize {
168 self.col_indices.len()
169 }
170
171 pub fn block_size(&self) -> usize {
173 self.block_size
174 }
175}
176
177impl SparseOps for BsrMatrix {
178 fn spmv(&self, alpha: f32, x: &[f32], beta: f32, y: &mut [f32]) -> Result<(), SparseError> {
179 if x.len() != self.cols() {
180 return Err(SparseError::SpMVDimensionMismatch {
181 matrix_cols: self.cols(),
182 x_len: x.len(),
183 });
184 }
185 if y.len() != self.rows() {
186 return Err(SparseError::SpMVOutputDimensionMismatch {
187 matrix_rows: self.rows(),
188 y_len: y.len(),
189 });
190 }
191
192 let bs = self.block_size;
193 let bs2 = bs * bs;
194
195 for yi in y.iter_mut() {
197 *yi *= beta;
198 }
199
200 for bi in 0..self.block_rows {
202 let blk_start = self.offsets[bi] as usize;
203 let blk_end = self.offsets[bi + 1] as usize;
204
205 for blk_idx in blk_start..blk_end {
206 let bj = self.col_indices[blk_idx] as usize;
207 let block = &self.values[blk_idx * bs2..(blk_idx + 1) * bs2];
208
209 for li in 0..bs {
210 let gi = bi * bs + li;
211 if gi >= y.len() {
212 break;
213 }
214 let mut sum = 0.0f32;
215 for lj in 0..bs {
216 let gj = bj * bs + lj;
217 if gj < x.len() {
218 sum += block[li * bs + lj] * x[gj];
219 }
220 }
221 y[gi] += alpha * sum;
222 }
223 }
224 }
225
226 Ok(())
227 }
228
229 fn spmm(
230 &self,
231 alpha: f32,
232 b: &[f32],
233 b_cols: usize,
234 beta: f32,
235 c: &mut [f32],
236 ) -> Result<(), SparseError> {
237 if b.len() != self.cols() * b_cols {
238 return Err(SparseError::SpMVDimensionMismatch {
239 matrix_cols: self.cols(),
240 x_len: b.len(),
241 });
242 }
243 if c.len() != self.rows() * b_cols {
244 return Err(SparseError::SpMVOutputDimensionMismatch {
245 matrix_rows: self.rows(),
246 y_len: c.len(),
247 });
248 }
249
250 let bs = self.block_size;
251 let bs2 = bs * bs;
252
253 for ci in c.iter_mut() {
255 *ci *= beta;
256 }
257
258 for bi in 0..self.block_rows {
260 let blk_start = self.offsets[bi] as usize;
261 let blk_end = self.offsets[bi + 1] as usize;
262
263 for blk_idx in blk_start..blk_end {
264 let bj = self.col_indices[blk_idx] as usize;
265 let block = &self.values[blk_idx * bs2..(blk_idx + 1) * bs2];
266
267 for li in 0..bs {
268 let gi = bi * bs + li;
269 if gi >= self.rows() {
270 break;
271 }
272 for lj in 0..bs {
273 let gj = bj * bs + lj;
274 if gj >= self.cols() {
275 continue;
276 }
277 let a_val = alpha * block[li * bs + lj];
278 for k in 0..b_cols {
279 c[gi * b_cols + k] += a_val * b[gj * b_cols + k];
280 }
281 }
282 }
283 }
284 }
285
286 Ok(())
287 }
288}