1use crate::error::{SparseError, SparseResult};
18use crate::formats::bsr::BSRMatrix;
19use scirs2_core::numeric::{One, SparseElement, Zero};
20use std::fmt::Debug;
21use std::ops::{Add, Mul};
22
/// Sparse matrix stored in Block Sparse Column (BSC) format.
///
/// The matrix is tiled into `r x c` blocks (`block_size = (r, c)`). Only stored
/// blocks occupy memory. Blocks are grouped by block column: the blocks of
/// block column `bj` live at storage positions `indptr[bj]..indptr[bj + 1]`.
/// Each block is kept densely in row-major order inside `data`.
///
/// All fields are public. Code that mutates them directly is responsible for
/// keeping them consistent with the invariants validated by `BSCMatrix::new`.
///
/// `PartialEq` compares the stored representation field by field. Two matrices
/// with the same values but different storage (for example an explicitly stored
/// all-zero block versus no block) therefore compare unequal.
#[derive(Debug, Clone, PartialEq)]
pub struct BSCMatrix<T> {
    /// Number of rows of the full matrix.
    pub nrows: usize,
    /// Number of columns of the full matrix.
    pub ncols: usize,
    /// Block dimensions `(r, c)`: rows and columns per block.
    pub block_size: (usize, usize),
    /// Number of block rows, `ceil(nrows / r)`.
    pub block_rows: usize,
    /// Number of block columns, `ceil(ncols / c)`.
    pub block_cols: usize,
    /// Block values. The `k`-th stored block occupies `data[k*r*c..(k+1)*r*c]`,
    /// row-major within the block. Edge blocks may hold entries that fall
    /// outside the matrix bounds; those are never read back.
    pub data: Vec<T>,
    /// Block-row index of each stored block (one entry per stored block).
    pub indices: Vec<usize>,
    /// Block-column pointer array of length `block_cols + 1`.
    pub indptr: Vec<usize>,
}
51
52impl<T> BSCMatrix<T>
53where
54 T: Clone + Copy + Zero + One + SparseElement + Debug + PartialEq,
55{
56 pub fn new(
69 data: Vec<T>,
70 indices: Vec<usize>,
71 indptr: Vec<usize>,
72 shape: (usize, usize),
73 block_size: (usize, usize),
74 ) -> SparseResult<Self> {
75 let (nrows, ncols) = shape;
76 let (r, c) = block_size;
77 if r == 0 || c == 0 {
78 return Err(SparseError::ValueError(
79 "BSC block dimensions must be positive".to_string(),
80 ));
81 }
82 let block_rows = nrows.div_ceil(r);
83 let block_cols = ncols.div_ceil(c);
84
85 if indptr.len() != block_cols + 1 {
86 return Err(SparseError::InconsistentData {
87 reason: format!(
88 "indptr length {} does not match block_cols+1 {}",
89 indptr.len(),
90 block_cols + 1
91 ),
92 });
93 }
94 let nnz_blocks = indices.len();
95 if data.len() != nnz_blocks * r * c {
96 return Err(SparseError::InconsistentData {
97 reason: format!(
98 "data length {} does not match nnz_blocks*r*c = {}*{}*{} = {}",
99 data.len(),
100 nnz_blocks,
101 r,
102 c,
103 nnz_blocks * r * c
104 ),
105 });
106 }
107 let last_ptr = *indptr.last().ok_or_else(|| SparseError::InconsistentData {
108 reason: "indptr is empty".to_string(),
109 })?;
110 if last_ptr != nnz_blocks {
111 return Err(SparseError::InconsistentData {
112 reason: "indptr last element must equal nnz_blocks".to_string(),
113 });
114 }
115 for bj in 0..block_cols {
116 if indptr[bj] > indptr[bj + 1] {
117 return Err(SparseError::InconsistentData {
118 reason: format!("indptr is not non-decreasing at position {}", bj),
119 });
120 }
121 }
122 for &bi in &indices {
123 if bi >= block_rows {
124 return Err(SparseError::IndexOutOfBounds {
125 index: (bi, 0),
126 shape: (block_rows, block_cols),
127 });
128 }
129 }
130
131 Ok(Self {
132 nrows,
133 ncols,
134 block_size,
135 block_rows,
136 block_cols,
137 data,
138 indices,
139 indptr,
140 })
141 }
142
143 pub fn zeros(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
145 let (_nrows, ncols) = shape;
146 let (_r, c) = block_size;
147 if _r == 0 || c == 0 {
148 return Err(SparseError::ValueError(
149 "BSC block dimensions must be positive".to_string(),
150 ));
151 }
152 let block_cols = ncols.div_ceil(c);
153 Self::new(
154 vec![],
155 vec![],
156 vec![0usize; block_cols + 1],
157 shape,
158 block_size,
159 )
160 }
161
162 pub fn from_bsr(bsr: &BSRMatrix<T>) -> SparseResult<Self>
166 where
167 T: Add<Output = T> + Mul<Output = T>,
168 {
169 let nrows = bsr.nrows;
172 let ncols = bsr.ncols;
173 let (r, c) = bsr.block_size;
174 let block_rows = bsr.block_rows;
175 let block_cols = bsr.block_cols;
176 let nnz_blocks = bsr.indices.len();
177
178 let mut col_counts = vec![0usize; block_cols];
180 for &bj in &bsr.indices {
181 col_counts[bj] += 1;
182 }
183 let mut bsc_indptr = vec![0usize; block_cols + 1];
184 for j in 0..block_cols {
185 bsc_indptr[j + 1] = bsc_indptr[j] + col_counts[j];
186 }
187
188 let mut bsc_indices = vec![0usize; nnz_blocks];
189 let mut bsc_data = vec![<T as Zero>::zero(); nnz_blocks * r * c];
190 let mut cur = bsc_indptr[..block_cols].to_vec();
191
192 for bi in 0..block_rows {
193 for pos in bsr.indptr[bi]..bsr.indptr[bi + 1] {
194 let bj = bsr.indices[pos];
195 let dst = cur[bj];
196 cur[bj] += 1;
197 bsc_indices[dst] = bi;
198 let src_base = pos * r * c;
199 let dst_base = dst * r * c;
200 bsc_data[dst_base..dst_base + r * c]
202 .copy_from_slice(&bsr.data[src_base..src_base + r * c]);
203 }
204 }
205
206 Self::new(bsc_data, bsc_indices, bsc_indptr, (nrows, ncols), (r, c))
207 }
208
209 pub fn from_dense(
211 dense: &[T],
212 nrows: usize,
213 ncols: usize,
214 block_size: (usize, usize),
215 ) -> SparseResult<Self> {
216 if dense.len() != nrows * ncols {
217 return Err(SparseError::InconsistentData {
218 reason: format!(
219 "dense slice length {} does not match nrows*ncols = {}",
220 dense.len(),
221 nrows * ncols
222 ),
223 });
224 }
225 let (r, c) = block_size;
226 if r == 0 || c == 0 {
227 return Err(SparseError::ValueError(
228 "Block dimensions must be positive".to_string(),
229 ));
230 }
231 let block_rows = nrows.div_ceil(r);
232 let block_cols = ncols.div_ceil(c);
233 let zero = <T as Zero>::zero();
234
235 let mut data: Vec<T> = Vec::new();
236 let mut indices: Vec<usize> = Vec::new();
237 let mut indptr = vec![0usize; block_cols + 1];
238
239 for bj in 0..block_cols {
241 let col_start = bj * c;
242 let col_end = col_start + c;
243 for bi in 0..block_rows {
244 let row_start = bi * r;
245 let row_end = row_start + r;
246 let mut block = Vec::with_capacity(r * c);
247 let mut all_zero = true;
248 for row in row_start..row_end {
249 for col in col_start..col_end {
250 let val = if row < nrows && col < ncols {
251 dense[row * ncols + col]
252 } else {
253 zero
254 };
255 if val != zero {
256 all_zero = false;
257 }
258 block.push(val);
259 }
260 }
261 if !all_zero {
262 data.extend_from_slice(&block);
263 indices.push(bi);
264 }
265 }
266 indptr[bj + 1] = indices.len();
267 }
268
269 Self::new(data, indices, indptr, (nrows, ncols), block_size)
270 }
271
272 pub fn to_dense(&self) -> Vec<T> {
278 let zero = <T as Zero>::zero();
279 let mut dense = vec![zero; self.nrows * self.ncols];
280 let (r, c) = self.block_size;
281
282 for bj in 0..self.block_cols {
283 let col_start = bj * c;
284 for pos in self.indptr[bj]..self.indptr[bj + 1] {
285 let bi = self.indices[pos];
286 let row_start = bi * r;
287 let base = pos * r * c;
288 for local_row in 0..r {
289 let matrix_row = row_start + local_row;
290 if matrix_row >= self.nrows {
291 break;
292 }
293 for local_col in 0..c {
294 let matrix_col = col_start + local_col;
295 if matrix_col >= self.ncols {
296 break;
297 }
298 dense[matrix_row * self.ncols + matrix_col] =
299 self.data[base + local_row * c + local_col];
300 }
301 }
302 }
303 }
304 dense
305 }
306
307 pub fn to_bsr(&self) -> SparseResult<BSRMatrix<T>>
309 where
310 T: Add<Output = T> + Mul<Output = T>,
311 {
312 let (r, c) = self.block_size;
314 let nnz_blocks = self.indices.len();
315
316 let mut row_counts = vec![0usize; self.block_rows];
318 for &bi in &self.indices {
319 row_counts[bi] += 1;
320 }
321 let mut bsr_indptr = vec![0usize; self.block_rows + 1];
322 for i in 0..self.block_rows {
323 bsr_indptr[i + 1] = bsr_indptr[i] + row_counts[i];
324 }
325
326 let mut bsr_indices = vec![0usize; nnz_blocks];
327 let mut bsr_data = vec![<T as Zero>::zero(); nnz_blocks * r * c];
328 let mut cur = bsr_indptr[..self.block_rows].to_vec();
329
330 for bj in 0..self.block_cols {
331 for pos in self.indptr[bj]..self.indptr[bj + 1] {
332 let bi = self.indices[pos];
333 let dst = cur[bi];
334 cur[bi] += 1;
335 bsr_indices[dst] = bj;
336 let src_base = pos * r * c;
337 let dst_base = dst * r * c;
338 bsr_data[dst_base..dst_base + r * c]
339 .copy_from_slice(&self.data[src_base..src_base + r * c]);
340 }
341 }
342
343 BSRMatrix::new(
344 bsr_data,
345 bsr_indices,
346 bsr_indptr,
347 (self.nrows, self.ncols),
348 self.block_size,
349 )
350 }
351
352 pub fn spmv(&self, x: &[T]) -> SparseResult<Vec<T>>
360 where
361 T: Add<Output = T> + Mul<Output = T>,
362 {
363 if x.len() != self.ncols {
364 return Err(SparseError::DimensionMismatch {
365 expected: self.ncols,
366 found: x.len(),
367 });
368 }
369 let zero = <T as Zero>::zero();
370 let mut y = vec![zero; self.nrows];
371 let (r, c) = self.block_size;
372
373 for bj in 0..self.block_cols {
374 let col_start = bj * c;
375 let col_end = (col_start + c).min(self.ncols);
376
377 for pos in self.indptr[bj]..self.indptr[bj + 1] {
378 let bi = self.indices[pos];
379 let row_start = bi * r;
380 let row_end = (row_start + r).min(self.nrows);
381 let base = pos * r * c;
382
383 for local_row in 0..(row_end - row_start) {
384 let mut acc = zero;
385 for local_col in 0..(col_end - col_start) {
386 acc = acc
387 + self.data[base + local_row * c + local_col]
388 * x[col_start + local_col];
389 }
390 y[row_start + local_row] = y[row_start + local_row] + acc;
391 }
392 }
393 }
394 Ok(y)
395 }
396
397 pub fn transpose_to_bsr(&self) -> SparseResult<BSRMatrix<T>>
403 where
404 T: Add<Output = T> + Mul<Output = T>,
405 {
406 let bsr = self.to_bsr()?;
408 bsr.transpose()
409 }
410
411 pub fn nnz_blocks(&self) -> usize {
417 self.indices.len()
418 }
419
420 pub fn nnz(&self) -> usize {
422 let (r, c) = self.block_size;
423 self.indices.len() * r * c
424 }
425
426 pub fn shape(&self) -> (usize, usize) {
428 (self.nrows, self.ncols)
429 }
430
431 pub fn get(&self, row: usize, col: usize) -> T {
433 if row >= self.nrows || col >= self.ncols {
434 return <T as Zero>::zero();
435 }
436 let (r, c) = self.block_size;
437 let bi = row / r;
438 let bj = col / c;
439 let local_row = row % r;
440 let local_col = col % c;
441
442 for pos in self.indptr[bj]..self.indptr[bj + 1] {
443 if self.indices[pos] == bi {
444 let base = pos * r * c;
445 return self.data[base + local_row * c + local_col];
446 }
447 }
448 <T as Zero>::zero()
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::formats::bsr::BSRMatrix;
    use approx::assert_relative_eq;

    // 4x4 block-diagonal BSR fixture with 2x2 blocks:
    //
    //   [1 2 0 0]
    //   [3 4 0 0]
    //   [0 0 5 6]
    //   [0 0 7 8]
    fn make_4x4_bsr() -> BSRMatrix<f64> {
        let data = vec![
            1.0_f64, 2.0, 3.0, 4.0, // block (0, 0)
            5.0, 6.0, 7.0, 8.0, // block (1, 1)
        ];
        let indices = vec![0, 1];
        let indptr = vec![0, 1, 2];
        BSRMatrix::new(data, indices, indptr, (4, 4), (2, 2)).expect("BSR construction failed")
    }

    #[test]
    fn test_from_bsr_roundtrip() {
        let bsr = make_4x4_bsr();
        let bsc = BSCMatrix::from_bsr(&bsr).expect("from_bsr failed");
        let bsr2 = bsc.to_bsr().expect("to_bsr failed");

        let dense_bsr = bsr.to_dense();
        let dense_bsr2 = bsr2.to_dense();
        for (a, b) in dense_bsr.iter().zip(dense_bsr2.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_from_dense() {
        let dense = vec![
            1.0_f64, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 7.0, 8.0,
        ];
        let bsc = BSCMatrix::from_dense(&dense, 4, 4, (2, 2)).expect("from_dense failed");
        assert_eq!(bsc.nnz_blocks(), 2);
        assert_eq!(bsc.get(0, 0), 1.0);
        assert_eq!(bsc.get(3, 3), 8.0);
        assert_eq!(bsc.get(0, 2), 0.0);
    }

    #[test]
    fn test_from_bsr_matches_from_dense() {
        let dense = vec![
            1.0_f64, 2.0, 0.0, 0.0, //
            3.0, 4.0, 0.0, 0.0, //
            0.0, 0.0, 5.0, 6.0, //
            0.0, 0.0, 7.0, 8.0,
        ];
        let via_bsr = BSCMatrix::from_bsr(&make_4x4_bsr()).expect("from_bsr failed");
        let via_dense = BSCMatrix::from_dense(&dense, 4, 4, (2, 2)).expect("from_dense failed");
        assert_eq!(via_bsr.indptr, via_dense.indptr);
        assert_eq!(via_bsr.indices, via_dense.indices);
        assert_eq!(via_bsr.to_dense(), dense);
    }

    #[test]
    fn test_spmv_matches_bsr() {
        let bsr = make_4x4_bsr();
        let bsc = BSCMatrix::from_bsr(&bsr).expect("from_bsr failed");
        let x = vec![1.0_f64, 2.0, 3.0, 4.0];
        let y_bsr = bsr.spmv(&x).expect("bsr spmv failed");
        let y_bsc = bsc.spmv(&x).expect("bsc spmv failed");
        for (a, b) in y_bsr.iter().zip(y_bsc.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_to_dense_consistent() {
        let dense_orig = vec![
            1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0,
        ];
        let bsc = BSCMatrix::from_dense(&dense_orig, 4, 4, (2, 2)).expect("from_dense failed");
        let recovered = bsc.to_dense();
        for (a, b) in recovered.iter().zip(dense_orig.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_ragged_shape_roundtrip_and_spmv() {
        // 3x5 matrix tiled by 2x2 blocks: the last block row and the last block
        // column extend past the matrix edge.
        let dense = vec![
            1.0_f64, 0.0, 2.0, 0.0, 0.0, //
            0.0, 0.0, 0.0, 0.0, 3.0, //
            4.0, 5.0, 0.0, 0.0, 6.0,
        ];
        let bsc = BSCMatrix::from_dense(&dense, 3, 5, (2, 2)).expect("from_dense failed");
        assert_eq!((bsc.block_rows, bsc.block_cols), (2, 3));
        // Only block (1, 1) is entirely zero, so 5 of the 6 blocks are stored.
        assert_eq!(bsc.nnz_blocks(), 5);
        assert_eq!(bsc.to_dense(), dense);
        for i in 0..3 {
            for j in 0..5 {
                assert_relative_eq!(bsc.get(i, j), dense[i * 5 + j], epsilon = 1e-12);
            }
        }
        // Row 3 lies inside a stored (padded) block row but outside the matrix.
        assert_eq!(bsc.get(3, 0), 0.0);

        let y = bsc.spmv(&[1.0, 2.0, 3.0, 4.0, 5.0]).expect("spmv failed");
        let expected = [7.0_f64, 15.0, 44.0];
        assert_eq!(y.len(), expected.len());
        for (a, b) in y.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_zeros() {
        let z = BSCMatrix::<f64>::zeros((5, 3), (2, 2)).expect("zeros failed");
        assert_eq!(z.shape(), (5, 3));
        assert_eq!(z.nnz_blocks(), 0);
        assert_eq!(z.nnz(), 0);
        assert_eq!(z.indptr, vec![0, 0, 0]);
        let dense = z.to_dense();
        assert_eq!(dense.len(), 15);
        assert!(dense.iter().all(|&v| v == 0.0));
        assert_eq!(z.spmv(&[1.0, 2.0, 3.0]).expect("spmv failed"), vec![0.0; 5]);
        assert!(BSCMatrix::<f64>::zeros((5, 3), (0, 2)).is_err());
    }

    #[test]
    fn test_new_rejects_inconsistent_input() {
        // Zero block dimension.
        assert!(BSCMatrix::<f64>::new(vec![], vec![], vec![0, 0], (2, 2), (0, 2)).is_err());
        // indptr length must be block_cols + 1 = 3.
        assert!(BSCMatrix::<f64>::new(vec![], vec![], vec![0], (4, 4), (2, 2)).is_err());
        // data length must be nnz_blocks * r * c = 4.
        assert!(BSCMatrix::new(vec![1.0_f64; 3], vec![0], vec![0, 1, 1], (4, 4), (2, 2)).is_err());
        // indptr must end at nnz_blocks.
        assert!(BSCMatrix::new(vec![1.0_f64; 4], vec![0], vec![0, 1, 2], (4, 4), (2, 2)).is_err());
        // indptr must be non-decreasing.
        assert!(BSCMatrix::new(vec![1.0_f64; 4], vec![0], vec![0, 2, 1], (4, 4), (2, 2)).is_err());
        // Block-row index must be < block_rows = 2.
        assert!(BSCMatrix::new(vec![1.0_f64; 4], vec![2], vec![0, 1, 1], (4, 4), (2, 2)).is_err());
        // A well-formed single block is accepted.
        assert!(BSCMatrix::new(vec![1.0_f64; 4], vec![1], vec![0, 0, 1], (4, 4), (2, 2)).is_ok());
    }

    #[test]
    fn test_dimension_errors_and_out_of_bounds_get() {
        let bsc = BSCMatrix::from_bsr(&make_4x4_bsr()).expect("from_bsr failed");
        assert!(bsc.spmv(&[1.0, 2.0]).is_err());
        assert!(BSCMatrix::from_dense(&[1.0_f64; 3], 2, 2, (1, 1)).is_err());
        assert_eq!(bsc.get(4, 0), 0.0);
        assert_eq!(bsc.get(0, 4), 0.0);
    }

    #[test]
    fn test_shape_and_nnz() {
        let bsr = make_4x4_bsr();
        let bsc = BSCMatrix::from_bsr(&bsr).expect("from_bsr failed");
        assert_eq!(bsc.shape(), (4, 4));
        assert_eq!(bsc.nnz_blocks(), 2);
        assert_eq!(bsc.nnz(), 8);
    }

    #[test]
    fn test_get_consistency_with_to_dense() {
        let bsr = make_4x4_bsr();
        let bsc = BSCMatrix::from_bsr(&bsr).expect("from_bsr failed");
        let dense = bsc.to_dense();
        for i in 0..4 {
            for j in 0..4 {
                assert_relative_eq!(bsc.get(i, j), dense[i * 4 + j], epsilon = 1e-12);
            }
        }
    }
}