1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csc_array::CscArray;
13use crate::csr_array::CsrArray;
14use crate::dia_array::DiaArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::{SparseArray, SparseSum};
19
20#[derive(Clone)]
35pub struct BsrArray<T>
36where
37 T: Float
38 + Add<Output = T>
39 + Sub<Output = T>
40 + Mul<Output = T>
41 + Div<Output = T>
42 + Debug
43 + Copy
44 + 'static
45 + std::ops::AddAssign,
46{
47 rows: usize,
49 cols: usize,
51 block_size: (usize, usize),
53 block_rows: usize,
55 #[allow(dead_code)]
57 block_cols: usize,
58 data: Vec<Vec<Vec<T>>>,
60 indices: Vec<Vec<usize>>,
62 indptr: Vec<usize>,
64}
65
66impl<T> BsrArray<T>
67where
68 T: Float
69 + Add<Output = T>
70 + Sub<Output = T>
71 + Mul<Output = T>
72 + Div<Output = T>
73 + Debug
74 + Copy
75 + 'static
76 + std::ops::AddAssign,
77{
78 pub fn new(
116 data: Vec<Vec<Vec<T>>>,
117 indices: Vec<Vec<usize>>,
118 indptr: Vec<usize>,
119 shape: (usize, usize),
120 block_size: (usize, usize),
121 ) -> SparseResult<Self> {
122 let (rows, cols) = shape;
123 let (r, c) = block_size;
124
125 if r == 0 || c == 0 {
126 return Err(SparseError::ValueError(
127 "Block dimensions must be positive".to_string(),
128 ));
129 }
130
131 #[allow(clippy::manual_div_ceil)]
133 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
135 let block_cols = (cols + c - 1) / c; if indptr.len() != block_rows + 1 {
139 return Err(SparseError::DimensionMismatch {
140 expected: block_rows + 1,
141 found: indptr.len(),
142 });
143 }
144
145 if data.len() != indptr[block_rows] {
146 return Err(SparseError::DimensionMismatch {
147 expected: indptr[block_rows],
148 found: data.len(),
149 });
150 }
151
152 if indices.len() != data.len() {
153 return Err(SparseError::DimensionMismatch {
154 expected: data.len(),
155 found: indices.len(),
156 });
157 }
158
159 for block in data.iter() {
160 if block.len() != r {
161 return Err(SparseError::DimensionMismatch {
162 expected: r,
163 found: block.len(),
164 });
165 }
166
167 for row in block.iter() {
168 if row.len() != c {
169 return Err(SparseError::DimensionMismatch {
170 expected: c,
171 found: row.len(),
172 });
173 }
174 }
175 }
176
177 for idx_vec in indices.iter() {
178 if idx_vec.len() != 1 {
179 return Err(SparseError::ValueError(
180 "Each index vector must contain exactly one block column index".to_string(),
181 ));
182 }
183 if idx_vec[0] >= block_cols {
184 return Err(SparseError::ValueError(format!(
185 "index {} out of bounds (max {})",
186 idx_vec[0],
187 block_cols - 1
188 )));
189 }
190 }
191
192 Ok(BsrArray {
193 rows,
194 cols,
195 block_size,
196 block_rows,
197 block_cols,
198 data,
199 indices,
200 indptr,
201 })
202 }
203
204 pub fn empty(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
215 let (rows, cols) = shape;
216 let (r, c) = block_size;
217
218 if r == 0 || c == 0 {
219 return Err(SparseError::ValueError(
220 "Block dimensions must be positive".to_string(),
221 ));
222 }
223
224 #[allow(clippy::manual_div_ceil)]
226 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
228 let block_cols = (cols + c - 1) / c; let data = Vec::new();
232 let indices = Vec::new();
233 let indptr = vec![0; block_rows + 1];
234
235 Ok(BsrArray {
236 rows,
237 cols,
238 block_size,
239 block_rows,
240 block_cols,
241 data,
242 indices,
243 indptr,
244 })
245 }
246
247 pub fn from_triplets(
261 row: &[usize],
262 col: &[usize],
263 data: &[T],
264 shape: (usize, usize),
265 block_size: (usize, usize),
266 ) -> SparseResult<Self> {
267 if row.len() != col.len() || row.len() != data.len() {
268 return Err(SparseError::InconsistentData {
269 reason: "Lengths of row, col, and data arrays must be equal".to_string(),
270 });
271 }
272
273 let (rows, cols) = shape;
274 let (r, c) = block_size;
275
276 if r == 0 || c == 0 {
277 return Err(SparseError::ValueError(
278 "Block dimensions must be positive".to_string(),
279 ));
280 }
281
282 #[allow(clippy::manual_div_ceil)]
284 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
286 let block_cols = (cols + c - 1) / c; let mut block_data = std::collections::HashMap::new();
290
291 for (&row_idx, (&col_idx, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
293 if row_idx >= rows || col_idx >= cols {
294 return Err(SparseError::IndexOutOfBounds {
295 index: (row_idx, col_idx),
296 shape,
297 });
298 }
299
300 let block_row = row_idx / r;
302 let block_col = col_idx / c;
303
304 let block_row_pos = row_idx % r;
306 let block_col_pos = col_idx % c;
307
308 let block = block_data.entry((block_row, block_col)).or_insert_with(|| {
310 let block = vec![vec![T::zero(); c]; r];
311 block
312 });
313
314 block[block_row_pos][block_col_pos] = val;
316 }
317
318 let mut rowswith_blocks: Vec<usize> = block_data.keys().map(|&(row_, _)| row_).collect();
320 rowswith_blocks.sort();
321 rowswith_blocks.dedup();
322
323 let mut indptr = vec![0; block_rows + 1];
325 let mut current_nnz = 0;
326
327 let mut data = Vec::new();
329 let mut indices = Vec::new();
330
331 for row_idx in 0..block_rows {
332 if rowswith_blocks.contains(&row_idx) {
333 let mut row_blocks: Vec<(usize, Vec<Vec<T>>)> = block_data
335 .iter()
336 .filter(|&(&(r, _), _)| r == row_idx)
337 .map(|(&(_, c), block)| (c, block.clone()))
338 .collect();
339
340 row_blocks.sort_by_key(|&(col_, _)| col_);
342
343 for (col, block) in row_blocks {
345 data.push(block);
346 indices.push(vec![col]);
347 current_nnz += 1;
348 }
349 }
350
351 indptr[row_idx + 1] = current_nnz;
352 }
353
354 BsrArray::new(data, indices, indptr, shape, block_size)
356 }
357
358 fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
360 let (r, c) = self.block_size;
361 let mut row_indices = Vec::new();
362 let mut col_indices = Vec::new();
363 let mut values = Vec::new();
364
365 for block_row in 0..self.block_rows {
366 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
367 let block_col = self.indices[k][0];
368 let block = &self.data[k];
369
370 for (i, block_row_data) in block.iter().enumerate().take(r) {
372 let row = block_row * r + i;
373 if row < self.rows {
374 for (j, &value) in block_row_data.iter().enumerate().take(c) {
375 let col = block_col * c + j;
376 if col < self.cols && !value.is_zero() {
377 row_indices.push(row);
378 col_indices.push(col);
379 values.push(value);
380 }
381 }
382 }
383 }
384 }
385 }
386
387 (row_indices, col_indices, values)
388 }
389}
390
391impl<T> SparseArray<T> for BsrArray<T>
392where
393 T: Float
394 + Add<Output = T>
395 + Sub<Output = T>
396 + Mul<Output = T>
397 + Div<Output = T>
398 + Debug
399 + Copy
400 + 'static
401 + std::ops::AddAssign,
402{
403 fn shape(&self) -> (usize, usize) {
404 (self.rows, self.cols)
405 }
406
407 fn nnz(&self) -> usize {
408 let mut count = 0;
409
410 for block in &self.data {
411 for row in block {
412 for &val in row {
413 if !val.is_zero() {
414 count += 1;
415 }
416 }
417 }
418 }
419
420 count
421 }
422
423 fn dtype(&self) -> &str {
424 "float" }
426
427 fn to_array(&self) -> Array2<T> {
428 let mut result = Array2::zeros((self.rows, self.cols));
429 let (r, c) = self.block_size;
430
431 for block_row in 0..self.block_rows {
432 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
433 let block_col = self.indices[k][0];
434 let block = &self.data[k];
435
436 for (i, block_row_data) in block.iter().enumerate().take(r) {
438 let row = block_row * r + i;
439 if row < self.rows {
440 for (j, &value) in block_row_data.iter().enumerate().take(c) {
441 let col = block_col * c + j;
442 if col < self.cols {
443 result[[row, col]] = value;
444 }
445 }
446 }
447 }
448 }
449 }
450
451 result
452 }
453
454 fn toarray(&self) -> Array2<T> {
455 self.to_array()
456 }
457
458 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
459 let (row_indices, col_indices, values) = self.to_coo_internal();
460 CooArray::from_triplets(
461 &row_indices,
462 &col_indices,
463 &values,
464 (self.rows, self.cols),
465 false,
466 )
467 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
468 }
469
470 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
471 let (row_indices, col_indices, values) = self.to_coo_internal();
472 CsrArray::from_triplets(
473 &row_indices,
474 &col_indices,
475 &values,
476 (self.rows, self.cols),
477 false,
478 )
479 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
480 }
481
482 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
483 let (row_indices, col_indices, values) = self.to_coo_internal();
484 CscArray::from_triplets(
485 &row_indices,
486 &col_indices,
487 &values,
488 (self.rows, self.cols),
489 false,
490 )
491 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
492 }
493
494 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
495 let (row_indices, col_indices, values) = self.to_coo_internal();
496 DokArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
497 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
498 }
499
500 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
501 let (row_indices, col_indices, values) = self.to_coo_internal();
502 LilArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
503 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
504 }
505
506 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
507 let (row_indices, col_indices, values) = self.to_coo_internal();
508 DiaArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
509 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
510 }
511
512 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
513 Ok(Box::new(self.clone()))
514 }
515
516 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
517 let csr_self = self.to_csr()?;
519 let csr_other = other.to_csr()?;
520 csr_self.add(&*csr_other)
521 }
522
523 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
524 let csr_self = self.to_csr()?;
526 let csr_other = other.to_csr()?;
527 csr_self.sub(&*csr_other)
528 }
529
530 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
531 let csr_self = self.to_csr()?;
533 let csr_other = other.to_csr()?;
534 csr_self.mul(&*csr_other)
535 }
536
537 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
538 let csr_self = self.to_csr()?;
540 let csr_other = other.to_csr()?;
541 csr_self.div(&*csr_other)
542 }
543
544 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
545 let (_, n) = self.shape();
546 let (p, q) = other.shape();
547
548 if n != p {
549 return Err(SparseError::DimensionMismatch {
550 expected: n,
551 found: p,
552 });
553 }
554
555 if q == 1 {
557 let other_array = other.to_array();
559 let vec_view = other_array.column(0);
560
561 let result = self.dot_vector(&vec_view)?;
563
564 let mut rows = Vec::new();
566 let mut cols = Vec::new();
567 let mut values = Vec::new();
568
569 for (i, &val) in result.iter().enumerate() {
570 if !val.is_zero() {
571 rows.push(i);
572 cols.push(0);
573 values.push(val);
574 }
575 }
576
577 CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
578 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
579 } else {
580 let csr_self = self.to_csr()?;
582 csr_self.dot(other)
583 }
584 }
585
586 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
587 let (rows, cols) = self.shape();
588 let (r, c) = self.block_size;
589
590 if cols != other.len() {
591 return Err(SparseError::DimensionMismatch {
592 expected: cols,
593 found: other.len(),
594 });
595 }
596
597 let mut result = Array1::zeros(rows);
598
599 for block_row in 0..self.block_rows {
600 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
601 let block_col = self.indices[k][0];
602 let block = &self.data[k];
603
604 for (i, block_row_data) in block.iter().enumerate().take(r) {
606 let row = block_row * r + i;
607 if row < self.rows {
608 for (j, &value) in block_row_data.iter().enumerate().take(c) {
609 let col = block_col * c + j;
610 if col < self.cols {
611 result[row] += value * other[col];
612 }
613 }
614 }
615 }
616 }
617 }
618
619 Ok(result)
620 }
621
622 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
623 self.to_coo()?.transpose()?.to_bsr()
625 }
626
627 fn copy(&self) -> Box<dyn SparseArray<T>> {
628 Box::new(self.clone())
629 }
630
631 fn get(&self, i: usize, j: usize) -> T {
632 if i >= self.rows || j >= self.cols {
633 return T::zero();
634 }
635
636 let (r, c) = self.block_size;
637 let block_row = i / r;
638 let block_col = j / c;
639 let block_row_pos = i % r;
640 let block_col_pos = j % c;
641
642 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
644 if self.indices[k][0] == block_col {
645 return self.data[k][block_row_pos][block_col_pos];
646 }
647 }
648
649 T::zero()
650 }
651
652 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
653 if i >= self.rows || j >= self.cols {
654 return Err(SparseError::IndexOutOfBounds {
655 index: (i, j),
656 shape: (self.rows, self.cols),
657 });
658 }
659
660 let (r, c) = self.block_size;
661 let block_row = i / r;
662 let block_col = j / c;
663 let block_row_pos = i % r;
664 let block_col_pos = j % c;
665
666 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
668 if self.indices[k][0] == block_col {
669 self.data[k][block_row_pos][block_col_pos] = value;
671 return Ok(());
672 }
673 }
674
675 if !value.is_zero() {
677 let pos = self.indptr[block_row + 1];
679
680 let mut block = vec![vec![T::zero(); c]; r];
682 block[block_row_pos][block_col_pos] = value;
683
684 self.data.insert(pos, block);
686 self.indices.insert(pos, vec![block_col]);
687
688 for k in (block_row + 1)..=self.block_rows {
690 self.indptr[k] += 1;
691 }
692
693 Ok(())
694 } else {
695 Ok(())
697 }
698 }
699
700 fn eliminate_zeros(&mut self) {
701 let mut new_data = Vec::new();
703 let mut new_indices = Vec::new();
704 let mut new_indptr = vec![0];
705 let mut current_nnz = 0;
706
707 for block_row in 0..self.block_rows {
708 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
709 let block_col = self.indices[k][0];
710 let block = &self.data[k];
711
712 let mut has_nonzero = false;
714 for row in block {
715 for &val in row {
716 if !val.is_zero() {
717 has_nonzero = true;
718 break;
719 }
720 }
721 if has_nonzero {
722 break;
723 }
724 }
725
726 if has_nonzero {
727 new_data.push(block.clone());
728 new_indices.push(vec![block_col]);
729 current_nnz += 1;
730 }
731 }
732
733 new_indptr.push(current_nnz);
734 }
735
736 self.data = new_data;
737 self.indices = new_indices;
738 self.indptr = new_indptr;
739 }
740
741 fn sort_indices(&mut self) {
742 let mut new_data = Vec::new();
744 let mut new_indices = Vec::new();
745 let mut new_indptr = vec![0];
746 let mut current_nnz = 0;
747
748 for block_row in 0..self.block_rows {
749 let mut row_blocks = Vec::new();
751 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
752 row_blocks.push((self.indices[k][0], self.data[k].clone()));
753 }
754
755 row_blocks.sort_by_key(|&(col_, _)| col_);
757
758 for (col, block) in row_blocks {
760 new_data.push(block);
761 new_indices.push(vec![col]);
762 current_nnz += 1;
763 }
764
765 new_indptr.push(current_nnz);
766 }
767
768 self.data = new_data;
769 self.indices = new_indices;
770 self.indptr = new_indptr;
771 }
772
773 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
774 let mut result = self.clone();
775 result.sort_indices();
776 Box::new(result)
777 }
778
779 fn has_sorted_indices(&self) -> bool {
780 for block_row in 0..self.block_rows {
781 let mut prev_col = None;
782
783 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
784 let col = self.indices[k][0];
785
786 if let Some(prev) = prev_col {
787 if col <= prev {
788 return false;
789 }
790 }
791
792 prev_col = Some(col);
793 }
794 }
795
796 true
797 }
798
799 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
800 match axis {
801 None => {
802 let mut total = T::zero();
804
805 for block in &self.data {
806 for row in block {
807 for &val in row {
808 total += val;
809 }
810 }
811 }
812
813 Ok(SparseSum::Scalar(total))
814 }
815 Some(0) => {
816 let mut result = vec![T::zero(); self.cols];
818 let (r, c) = self.block_size;
819
820 for block_row in 0..self.block_rows {
821 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
822 let block_col = self.indices[k][0];
823 let block = &self.data[k];
824
825 for block_row_data in block.iter().take(r) {
826 for (j, &value) in block_row_data.iter().enumerate().take(c) {
827 let col = block_col * c + j;
828 if col < self.cols {
829 result[col] += value;
830 }
831 }
832 }
833 }
834 }
835
836 let mut row_indices = Vec::new();
838 let mut col_indices = Vec::new();
839 let mut values = Vec::new();
840
841 for (j, &val) in result.iter().enumerate() {
842 if !val.is_zero() {
843 row_indices.push(0);
844 col_indices.push(j);
845 values.push(val);
846 }
847 }
848
849 match CooArray::from_triplets(
850 &row_indices,
851 &col_indices,
852 &values,
853 (1, self.cols),
854 false,
855 ) {
856 Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
857 Err(e) => Err(e),
858 }
859 }
860 Some(1) => {
861 let mut result = vec![T::zero(); self.rows];
863 let (r, c) = self.block_size;
864
865 for block_row in 0..self.block_rows {
866 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
867 let block = &self.data[k];
868
869 for (i, block_row_data) in block.iter().enumerate().take(r) {
870 let row = block_row * r + i;
871 if row < self.rows {
872 for &value in block_row_data.iter().take(c) {
873 result[row] += value;
874 }
875 }
876 }
877 }
878 }
879
880 let mut row_indices = Vec::new();
882 let mut col_indices = Vec::new();
883 let mut values = Vec::new();
884
885 for (i, &val) in result.iter().enumerate() {
886 if !val.is_zero() {
887 row_indices.push(i);
888 col_indices.push(0);
889 values.push(val);
890 }
891 }
892
893 match CooArray::from_triplets(
894 &row_indices,
895 &col_indices,
896 &values,
897 (self.rows, 1),
898 false,
899 ) {
900 Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
901 Err(e) => Err(e),
902 }
903 }
904 _ => Err(SparseError::InvalidAxis),
905 }
906 }
907
908 fn max(&self) -> T {
909 let mut max_val = T::neg_infinity();
910
911 for block in &self.data {
912 for row in block {
913 for &val in row {
914 max_val = max_val.max(val);
915 }
916 }
917 }
918
919 if max_val == T::neg_infinity() {
921 T::zero()
922 } else {
923 max_val
924 }
925 }
926
927 fn min(&self) -> T {
928 let mut min_val = T::infinity();
929 let mut has_nonzero = false;
930
931 for block in &self.data {
932 for row in block {
933 for &val in row {
934 if !val.is_zero() {
935 has_nonzero = true;
936 min_val = min_val.min(val);
937 }
938 }
939 }
940 }
941
942 if !has_nonzero {
944 T::zero()
945 } else {
946 min_val
947 }
948 }
949
950 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
951 let (row_indices, col_indices, values) = self.to_coo_internal();
952
953 (
954 Array1::from_vec(row_indices),
955 Array1::from_vec(col_indices),
956 Array1::from_vec(values),
957 )
958 }
959
960 fn slice(
961 &self,
962 row_range: (usize, usize),
963 col_range: (usize, usize),
964 ) -> SparseResult<Box<dyn SparseArray<T>>> {
965 let (start_row, end_row) = row_range;
966 let (start_col, end_col) = col_range;
967 let (rows, cols) = self.shape();
968
969 if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
970 return Err(SparseError::IndexOutOfBounds {
971 index: (start_row.max(end_row), start_col.max(end_col)),
972 shape: (rows, cols),
973 });
974 }
975
976 if start_row >= end_row || start_col >= end_col {
977 return Err(SparseError::InvalidSliceRange);
978 }
979
980 let coo = self.to_coo()?;
982 coo.slice(row_range, col_range)?.to_bsr()
983 }
984
985 fn as_any(&self) -> &dyn std::any::Any {
986 self
987 }
988}
989
990impl<T> fmt::Display for BsrArray<T>
992where
993 T: Float
994 + Add<Output = T>
995 + Sub<Output = T>
996 + Mul<Output = T>
997 + Div<Output = T>
998 + Debug
999 + Copy
1000 + 'static
1001 + std::ops::AddAssign,
1002{
1003 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1004 writeln!(
1005 f,
1006 "BsrArray of shape {:?} with {} stored elements",
1007 (self.rows, self.cols),
1008 self.nnz()
1009 )?;
1010 writeln!(f, "Block size: {:?}", self.block_size)?;
1011 writeln!(f, "Number of blocks: {}", self.data.len())?;
1012
1013 if self.data.len() <= 5 {
1014 for block_row in 0..self.block_rows {
1015 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
1016 let block_col = self.indices[k][0];
1017 let block = &self.data[k];
1018
1019 writeln!(f, "Block at ({block_row}, {block_col}): ")?;
1020 for row in block {
1021 write!(f, " [")?;
1022 for (j, &val) in row.iter().enumerate() {
1023 if j > 0 {
1024 write!(f, ", ")?;
1025 }
1026 write!(f, "{val:?}")?;
1027 }
1028 writeln!(f, "]")?;
1029 }
1030 }
1031 }
1032 } else {
1033 writeln!(f, "({} blocks total)", self.data.len())?;
1034 }
1035
1036 Ok(())
1037 }
1038}
1039
#[cfg(test)]
mod tests {
    use super::*;

    /// Dense entries of the block-diagonal 4x4 fixture. The first eight are
    /// the stored values; the last is a position outside both blocks.
    const FIXTURE_ENTRIES: [(usize, usize, f64); 9] = [
        (0, 0, 1.0),
        (0, 1, 2.0),
        (1, 0, 3.0),
        (1, 1, 4.0),
        (2, 2, 5.0),
        (2, 3, 6.0),
        (3, 2, 7.0),
        (3, 3, 8.0),
        (0, 2, 0.0),
    ];

    /// 4x4 matrix with two 2x2 blocks on the block diagonal.
    fn block_diagonal_4x4() -> BsrArray<f64> {
        let upper_left = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let lower_right = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
        BsrArray::new(
            vec![upper_left, lower_right],
            vec![vec![0], vec![1]],
            vec![0, 1, 2],
            (4, 4),
            (2, 2),
        )
        .unwrap()
    }

    /// 2x2 matrix stored as a single 2x2 block.
    fn single_block_2x2(block: Vec<Vec<f64>>) -> BsrArray<f64> {
        BsrArray::new(vec![block], vec![vec![0]], vec![0, 1], (2, 2), (2, 2)).unwrap()
    }

    /// Checks shape, block size, nnz and every entry listed in `FIXTURE_ENTRIES`.
    fn assert_matches_fixture(array: &BsrArray<f64>) {
        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8);
        for &(i, j, expected) in FIXTURE_ENTRIES.iter() {
            assert_eq!(array.get(i, j), expected);
        }
    }

    /// Checks that `array` is 2x2 with the given row-major values.
    fn assert_2x2(array: &dyn SparseArray<f64>, expected: [f64; 4]) {
        assert_eq!(array.shape(), (2, 2));
        for (pos, &value) in expected.iter().enumerate() {
            assert_eq!(array.get(pos / 2, pos % 2), value);
        }
    }

    #[test]
    fn test_bsr_array_create() {
        assert_matches_fixture(&block_diagonal_4x4());
    }

    #[test]
    fn test_bsr_array_from_triplets() {
        let stored = &FIXTURE_ENTRIES[..8];
        let rows: Vec<usize> = stored.iter().map(|&(i, _, _)| i).collect();
        let cols: Vec<usize> = stored.iter().map(|&(_, j, _)| j).collect();
        let data: Vec<f64> = stored.iter().map(|&(_, _, v)| v).collect();

        let array = BsrArray::from_triplets(&rows, &cols, &data, (4, 4), (2, 2)).unwrap();
        assert_matches_fixture(&array);
    }

    #[test]
    fn test_bsr_array_conversion() {
        let array = block_diagonal_4x4();

        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (4, 4));
        assert_eq!(coo.nnz(), 8);

        let csr = array.to_csr().unwrap();
        assert_eq!(csr.shape(), (4, 4));
        assert_eq!(csr.nnz(), 8);

        let expected = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        assert_eq!(array.to_array(), expected);
    }

    #[test]
    fn test_bsr_array_operations() {
        let left = single_block_2x2(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let right = single_block_2x2(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);

        let sum = left.add(&right).unwrap();
        assert_2x2(&*sum, [6.0, 8.0, 10.0, 12.0]);

        let product = left.mul(&right).unwrap();
        assert_2x2(&*product, [5.0, 12.0, 21.0, 32.0]);

        let dot = left.dot(&right).unwrap();
        assert_2x2(&*dot, [19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_bsr_array_dot_vector() {
        let array = block_diagonal_4x4();
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = array.dot_vector(&vector.view()).unwrap();

        assert_eq!(result, Array1::from_vec(vec![5.0, 11.0, 39.0, 53.0]));
    }

    #[test]
    fn test_bsr_array_sum() {
        let array = block_diagonal_4x4();

        match array.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 36.0),
            _ => panic!("Expected SparseSum::Scalar"),
        }

        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(row_sum) => {
                assert_eq!(row_sum.shape(), (1, 4));
                for (j, &expected) in [4.0, 6.0, 12.0, 14.0].iter().enumerate() {
                    assert_eq!(row_sum.get(0, j), expected);
                }
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }

        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(col_sum) => {
                assert_eq!(col_sum.shape(), (4, 1));
                for (i, &expected) in [3.0, 7.0, 11.0, 15.0].iter().enumerate() {
                    assert_eq!(col_sum.get(i, 0), expected);
                }
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }
    }
}