1use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csc_array::CscArray;
13use crate::csr_array::CsrArray;
14use crate::dia_array::DiaArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::{SparseArray, SparseSum};
19
/// Block Sparse Row (BSR) array.
///
/// The matrix is tiled into dense `r x c` blocks (`block_size`). Only blocks that hold data are
/// stored, grouped by block row in a CSR-like layout:
///
/// * `data[k]` is stored block `k`, kept as `r` rows of `c` values;
/// * `indices[k][0]` is the block-column index of block `k`;
/// * the blocks of block row `i` are `data[indptr[i]..indptr[i + 1]]`.
///
/// When the shape is not a multiple of the block size, edge blocks extend past `(rows, cols)`.
/// Values stored in that padding are not part of the matrix.
///
/// Numeric trait bounds live on the `impl` blocks that need them rather than on the type itself.
#[derive(Clone)]
pub struct BsrArray<T> {
    /// Number of rows of the logical matrix.
    rows: usize,
    /// Number of columns of the logical matrix.
    cols: usize,
    /// Block dimensions `(r, c)`.
    block_size: (usize, usize),
    /// Number of block rows, `ceil(rows / r)`.
    block_rows: usize,
    /// Number of block columns, `ceil(cols / c)`.
    _block_cols: usize,
    /// Dense contents of each stored block (`r` rows of `c` values).
    data: Vec<Vec<Vec<T>>>,
    /// One-element vectors holding the block-column index of each stored block.
    indices: Vec<Vec<usize>>,
    /// Block-row pointer of length `block_rows + 1` into `data` / `indices`.
    indptr: Vec<usize>,
}
64
65impl<T> BsrArray<T>
66where
67 T: Float
68 + Add<Output = T>
69 + Sub<Output = T>
70 + Mul<Output = T>
71 + Div<Output = T>
72 + Debug
73 + Copy
74 + 'static
75 + std::ops::AddAssign,
76{
77 pub fn new(
115 data: Vec<Vec<Vec<T>>>,
116 indices: Vec<Vec<usize>>,
117 indptr: Vec<usize>,
118 shape: (usize, usize),
119 block_size: (usize, usize),
120 ) -> SparseResult<Self> {
121 let (rows, cols) = shape;
122 let (r, c) = block_size;
123
124 if r == 0 || c == 0 {
125 return Err(SparseError::ValueError(
126 "Block dimensions must be positive".to_string(),
127 ));
128 }
129
130 #[allow(clippy::manual_div_ceil)]
132 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
134 let _block_cols = (cols + c - 1) / c; if indptr.len() != block_rows + 1 {
138 return Err(SparseError::DimensionMismatch {
139 expected: block_rows + 1,
140 found: indptr.len(),
141 });
142 }
143
144 if data.len() != indptr[block_rows] {
145 return Err(SparseError::DimensionMismatch {
146 expected: indptr[block_rows],
147 found: data.len(),
148 });
149 }
150
151 if indices.len() != data.len() {
152 return Err(SparseError::DimensionMismatch {
153 expected: data.len(),
154 found: indices.len(),
155 });
156 }
157
158 for block in data.iter() {
159 if block.len() != r {
160 return Err(SparseError::DimensionMismatch {
161 expected: r,
162 found: block.len(),
163 });
164 }
165
166 for row in block.iter() {
167 if row.len() != c {
168 return Err(SparseError::DimensionMismatch {
169 expected: c,
170 found: row.len(),
171 });
172 }
173 }
174 }
175
176 for idx_vec in indices.iter() {
177 if idx_vec.len() != 1 {
178 return Err(SparseError::ValueError(
179 "Each index vector must contain exactly one block column index".to_string(),
180 ));
181 }
182 if idx_vec[0] >= _block_cols {
183 return Err(SparseError::ValueError(format!(
184 "index {} out of bounds (max {})",
185 idx_vec[0],
186 _block_cols - 1
187 )));
188 }
189 }
190
191 Ok(BsrArray {
192 rows,
193 cols,
194 block_size,
195 block_rows,
196 _block_cols,
197 data,
198 indices,
199 indptr,
200 })
201 }
202
203 pub fn empty(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
214 let (rows, cols) = shape;
215 let (r, c) = block_size;
216
217 if r == 0 || c == 0 {
218 return Err(SparseError::ValueError(
219 "Block dimensions must be positive".to_string(),
220 ));
221 }
222
223 #[allow(clippy::manual_div_ceil)]
225 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
227 let _block_cols = (cols + c - 1) / c; let data = Vec::new();
231 let indices = Vec::new();
232 let indptr = vec![0; block_rows + 1];
233
234 Ok(BsrArray {
235 rows,
236 cols,
237 block_size,
238 block_rows,
239 _block_cols,
240 data,
241 indices,
242 indptr,
243 })
244 }
245
246 pub fn from_triplets(
260 row: &[usize],
261 col: &[usize],
262 data: &[T],
263 shape: (usize, usize),
264 block_size: (usize, usize),
265 ) -> SparseResult<Self> {
266 if row.len() != col.len() || row.len() != data.len() {
267 return Err(SparseError::InconsistentData {
268 reason: "Lengths of row, col, and data arrays must be equal".to_string(),
269 });
270 }
271
272 let (rows, cols) = shape;
273 let (r, c) = block_size;
274
275 if r == 0 || c == 0 {
276 return Err(SparseError::ValueError(
277 "Block dimensions must be positive".to_string(),
278 ));
279 }
280
281 #[allow(clippy::manual_div_ceil)]
283 let block_rows = (rows + r - 1) / r; #[allow(clippy::manual_div_ceil)]
285 let _block_cols = (cols + c - 1) / c; let mut block_data = std::collections::HashMap::new();
289
290 for (&row_idx, (&col_idx, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
292 if row_idx >= rows || col_idx >= cols {
293 return Err(SparseError::IndexOutOfBounds {
294 index: (row_idx, col_idx),
295 shape,
296 });
297 }
298
299 let block_row = row_idx / r;
301 let block_col = col_idx / c;
302
303 let block_row_pos = row_idx % r;
305 let block_col_pos = col_idx % c;
306
307 let block = block_data.entry((block_row, block_col)).or_insert_with(|| {
309 let block = vec![vec![T::zero(); c]; r];
310 block
311 });
312
313 block[block_row_pos][block_col_pos] = val;
315 }
316
317 let mut rows_with_blocks: Vec<usize> = block_data.keys().map(|&(row, _)| row).collect();
319 rows_with_blocks.sort();
320 rows_with_blocks.dedup();
321
322 let mut indptr = vec![0; block_rows + 1];
324 let mut current_nnz = 0;
325
326 let mut data = Vec::new();
328 let mut indices = Vec::new();
329
330 for row_idx in 0..block_rows {
331 if rows_with_blocks.contains(&row_idx) {
332 let mut row_blocks: Vec<(usize, Vec<Vec<T>>)> = block_data
334 .iter()
335 .filter(|&(&(r, _), _)| r == row_idx)
336 .map(|(&(_, c), block)| (c, block.clone()))
337 .collect();
338
339 row_blocks.sort_by_key(|&(col, _)| col);
341
342 for (col, block) in row_blocks {
344 data.push(block);
345 indices.push(vec![col]);
346 current_nnz += 1;
347 }
348 }
349
350 indptr[row_idx + 1] = current_nnz;
351 }
352
353 BsrArray::new(data, indices, indptr, shape, block_size)
355 }
356
357 fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
359 let (r, c) = self.block_size;
360 let mut row_indices = Vec::new();
361 let mut col_indices = Vec::new();
362 let mut values = Vec::new();
363
364 for block_row in 0..self.block_rows {
365 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
366 let block_col = self.indices[k][0];
367 let block = &self.data[k];
368
369 for (i, block_row_data) in block.iter().enumerate().take(r) {
371 let row = block_row * r + i;
372 if row < self.rows {
373 for (j, &value) in block_row_data.iter().enumerate().take(c) {
374 let col = block_col * c + j;
375 if col < self.cols && !value.is_zero() {
376 row_indices.push(row);
377 col_indices.push(col);
378 values.push(value);
379 }
380 }
381 }
382 }
383 }
384 }
385
386 (row_indices, col_indices, values)
387 }
388}
389
390impl<T> SparseArray<T> for BsrArray<T>
391where
392 T: Float
393 + Add<Output = T>
394 + Sub<Output = T>
395 + Mul<Output = T>
396 + Div<Output = T>
397 + Debug
398 + Copy
399 + 'static
400 + std::ops::AddAssign,
401{
402 fn shape(&self) -> (usize, usize) {
403 (self.rows, self.cols)
404 }
405
406 fn nnz(&self) -> usize {
407 let mut count = 0;
408
409 for block in &self.data {
410 for row in block {
411 for &val in row {
412 if !val.is_zero() {
413 count += 1;
414 }
415 }
416 }
417 }
418
419 count
420 }
421
422 fn dtype(&self) -> &str {
423 "float" }
425
426 fn to_array(&self) -> Array2<T> {
427 let mut result = Array2::zeros((self.rows, self.cols));
428 let (r, c) = self.block_size;
429
430 for block_row in 0..self.block_rows {
431 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
432 let block_col = self.indices[k][0];
433 let block = &self.data[k];
434
435 for (i, block_row_data) in block.iter().enumerate().take(r) {
437 let row = block_row * r + i;
438 if row < self.rows {
439 for (j, &value) in block_row_data.iter().enumerate().take(c) {
440 let col = block_col * c + j;
441 if col < self.cols {
442 result[[row, col]] = value;
443 }
444 }
445 }
446 }
447 }
448 }
449
450 result
451 }
452
453 fn toarray(&self) -> Array2<T> {
454 self.to_array()
455 }
456
457 fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
458 let (row_indices, col_indices, values) = self.to_coo_internal();
459 CooArray::from_triplets(
460 &row_indices,
461 &col_indices,
462 &values,
463 (self.rows, self.cols),
464 false,
465 )
466 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
467 }
468
469 fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
470 let (row_indices, col_indices, values) = self.to_coo_internal();
471 CsrArray::from_triplets(
472 &row_indices,
473 &col_indices,
474 &values,
475 (self.rows, self.cols),
476 false,
477 )
478 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
479 }
480
481 fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
482 let (row_indices, col_indices, values) = self.to_coo_internal();
483 CscArray::from_triplets(
484 &row_indices,
485 &col_indices,
486 &values,
487 (self.rows, self.cols),
488 false,
489 )
490 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
491 }
492
493 fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
494 let (row_indices, col_indices, values) = self.to_coo_internal();
495 DokArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
496 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
497 }
498
499 fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
500 let (row_indices, col_indices, values) = self.to_coo_internal();
501 LilArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
502 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
503 }
504
505 fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
506 let (row_indices, col_indices, values) = self.to_coo_internal();
507 DiaArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
508 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
509 }
510
511 fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
512 Ok(Box::new(self.clone()))
513 }
514
515 fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
516 let csr_self = self.to_csr()?;
518 let csr_other = other.to_csr()?;
519 csr_self.add(&*csr_other)
520 }
521
522 fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
523 let csr_self = self.to_csr()?;
525 let csr_other = other.to_csr()?;
526 csr_self.sub(&*csr_other)
527 }
528
529 fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
530 let csr_self = self.to_csr()?;
532 let csr_other = other.to_csr()?;
533 csr_self.mul(&*csr_other)
534 }
535
536 fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
537 let csr_self = self.to_csr()?;
539 let csr_other = other.to_csr()?;
540 csr_self.div(&*csr_other)
541 }
542
543 fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
544 let (_, n) = self.shape();
545 let (p, q) = other.shape();
546
547 if n != p {
548 return Err(SparseError::DimensionMismatch {
549 expected: n,
550 found: p,
551 });
552 }
553
554 if q == 1 {
556 let other_array = other.to_array();
558 let vec_view = other_array.column(0);
559
560 let result = self.dot_vector(&vec_view)?;
562
563 let mut rows = Vec::new();
565 let mut cols = Vec::new();
566 let mut values = Vec::new();
567
568 for (i, &val) in result.iter().enumerate() {
569 if !val.is_zero() {
570 rows.push(i);
571 cols.push(0);
572 values.push(val);
573 }
574 }
575
576 CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
577 .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
578 } else {
579 let csr_self = self.to_csr()?;
581 csr_self.dot(other)
582 }
583 }
584
585 fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
586 let (rows, cols) = self.shape();
587 let (r, c) = self.block_size;
588
589 if cols != other.len() {
590 return Err(SparseError::DimensionMismatch {
591 expected: cols,
592 found: other.len(),
593 });
594 }
595
596 let mut result = Array1::zeros(rows);
597
598 for block_row in 0..self.block_rows {
599 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
600 let block_col = self.indices[k][0];
601 let block = &self.data[k];
602
603 for (i, block_row_data) in block.iter().enumerate().take(r) {
605 let row = block_row * r + i;
606 if row < self.rows {
607 for (j, &value) in block_row_data.iter().enumerate().take(c) {
608 let col = block_col * c + j;
609 if col < self.cols {
610 result[row] += value * other[col];
611 }
612 }
613 }
614 }
615 }
616 }
617
618 Ok(result)
619 }
620
621 fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
622 self.to_coo()?.transpose()?.to_bsr()
624 }
625
626 fn copy(&self) -> Box<dyn SparseArray<T>> {
627 Box::new(self.clone())
628 }
629
630 fn get(&self, i: usize, j: usize) -> T {
631 if i >= self.rows || j >= self.cols {
632 return T::zero();
633 }
634
635 let (r, c) = self.block_size;
636 let block_row = i / r;
637 let block_col = j / c;
638 let block_row_pos = i % r;
639 let block_col_pos = j % c;
640
641 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
643 if self.indices[k][0] == block_col {
644 return self.data[k][block_row_pos][block_col_pos];
645 }
646 }
647
648 T::zero()
649 }
650
651 fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
652 if i >= self.rows || j >= self.cols {
653 return Err(SparseError::IndexOutOfBounds {
654 index: (i, j),
655 shape: (self.rows, self.cols),
656 });
657 }
658
659 let (r, c) = self.block_size;
660 let block_row = i / r;
661 let block_col = j / c;
662 let block_row_pos = i % r;
663 let block_col_pos = j % c;
664
665 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
667 if self.indices[k][0] == block_col {
668 self.data[k][block_row_pos][block_col_pos] = value;
670 return Ok(());
671 }
672 }
673
674 if !value.is_zero() {
676 let pos = self.indptr[block_row + 1];
678
679 let mut block = vec![vec![T::zero(); c]; r];
681 block[block_row_pos][block_col_pos] = value;
682
683 self.data.insert(pos, block);
685 self.indices.insert(pos, vec![block_col]);
686
687 for k in (block_row + 1)..=self.block_rows {
689 self.indptr[k] += 1;
690 }
691
692 Ok(())
693 } else {
694 Ok(())
696 }
697 }
698
699 fn eliminate_zeros(&mut self) {
700 let mut new_data = Vec::new();
702 let mut new_indices = Vec::new();
703 let mut new_indptr = vec![0];
704 let mut current_nnz = 0;
705
706 for block_row in 0..self.block_rows {
707 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
708 let block_col = self.indices[k][0];
709 let block = &self.data[k];
710
711 let mut has_nonzero = false;
713 for row in block {
714 for &val in row {
715 if !val.is_zero() {
716 has_nonzero = true;
717 break;
718 }
719 }
720 if has_nonzero {
721 break;
722 }
723 }
724
725 if has_nonzero {
726 new_data.push(block.clone());
727 new_indices.push(vec![block_col]);
728 current_nnz += 1;
729 }
730 }
731
732 new_indptr.push(current_nnz);
733 }
734
735 self.data = new_data;
736 self.indices = new_indices;
737 self.indptr = new_indptr;
738 }
739
740 fn sort_indices(&mut self) {
741 let mut new_data = Vec::new();
743 let mut new_indices = Vec::new();
744 let mut new_indptr = vec![0];
745 let mut current_nnz = 0;
746
747 for block_row in 0..self.block_rows {
748 let mut row_blocks = Vec::new();
750 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
751 row_blocks.push((self.indices[k][0], self.data[k].clone()));
752 }
753
754 row_blocks.sort_by_key(|&(col, _)| col);
756
757 for (col, block) in row_blocks {
759 new_data.push(block);
760 new_indices.push(vec![col]);
761 current_nnz += 1;
762 }
763
764 new_indptr.push(current_nnz);
765 }
766
767 self.data = new_data;
768 self.indices = new_indices;
769 self.indptr = new_indptr;
770 }
771
772 fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
773 let mut result = self.clone();
774 result.sort_indices();
775 Box::new(result)
776 }
777
778 fn has_sorted_indices(&self) -> bool {
779 for block_row in 0..self.block_rows {
780 let mut prev_col = None;
781
782 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
783 let col = self.indices[k][0];
784
785 if let Some(prev) = prev_col {
786 if col <= prev {
787 return false;
788 }
789 }
790
791 prev_col = Some(col);
792 }
793 }
794
795 true
796 }
797
798 fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
799 match axis {
800 None => {
801 let mut total = T::zero();
803
804 for block in &self.data {
805 for row in block {
806 for &val in row {
807 total += val;
808 }
809 }
810 }
811
812 Ok(SparseSum::Scalar(total))
813 }
814 Some(0) => {
815 let mut result = vec![T::zero(); self.cols];
817 let (r, c) = self.block_size;
818
819 for block_row in 0..self.block_rows {
820 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
821 let block_col = self.indices[k][0];
822 let block = &self.data[k];
823
824 for block_row_data in block.iter().take(r) {
825 for (j, &value) in block_row_data.iter().enumerate().take(c) {
826 let col = block_col * c + j;
827 if col < self.cols {
828 result[col] += value;
829 }
830 }
831 }
832 }
833 }
834
835 let mut row_indices = Vec::new();
837 let mut col_indices = Vec::new();
838 let mut values = Vec::new();
839
840 for (j, &val) in result.iter().enumerate() {
841 if !val.is_zero() {
842 row_indices.push(0);
843 col_indices.push(j);
844 values.push(val);
845 }
846 }
847
848 match CooArray::from_triplets(
849 &row_indices,
850 &col_indices,
851 &values,
852 (1, self.cols),
853 false,
854 ) {
855 Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
856 Err(e) => Err(e),
857 }
858 }
859 Some(1) => {
860 let mut result = vec![T::zero(); self.rows];
862 let (r, c) = self.block_size;
863
864 for block_row in 0..self.block_rows {
865 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
866 let block = &self.data[k];
867
868 for (i, block_row_data) in block.iter().enumerate().take(r) {
869 let row = block_row * r + i;
870 if row < self.rows {
871 for &value in block_row_data.iter().take(c) {
872 result[row] += value;
873 }
874 }
875 }
876 }
877 }
878
879 let mut row_indices = Vec::new();
881 let mut col_indices = Vec::new();
882 let mut values = Vec::new();
883
884 for (i, &val) in result.iter().enumerate() {
885 if !val.is_zero() {
886 row_indices.push(i);
887 col_indices.push(0);
888 values.push(val);
889 }
890 }
891
892 match CooArray::from_triplets(
893 &row_indices,
894 &col_indices,
895 &values,
896 (self.rows, 1),
897 false,
898 ) {
899 Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
900 Err(e) => Err(e),
901 }
902 }
903 _ => Err(SparseError::InvalidAxis),
904 }
905 }
906
907 fn max(&self) -> T {
908 let mut max_val = T::neg_infinity();
909
910 for block in &self.data {
911 for row in block {
912 for &val in row {
913 max_val = max_val.max(val);
914 }
915 }
916 }
917
918 if max_val == T::neg_infinity() {
920 T::zero()
921 } else {
922 max_val
923 }
924 }
925
926 fn min(&self) -> T {
927 let mut min_val = T::infinity();
928 let mut has_nonzero = false;
929
930 for block in &self.data {
931 for row in block {
932 for &val in row {
933 if !val.is_zero() {
934 has_nonzero = true;
935 min_val = min_val.min(val);
936 }
937 }
938 }
939 }
940
941 if !has_nonzero {
943 T::zero()
944 } else {
945 min_val
946 }
947 }
948
949 fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
950 let (row_indices, col_indices, values) = self.to_coo_internal();
951
952 (
953 Array1::from_vec(row_indices),
954 Array1::from_vec(col_indices),
955 Array1::from_vec(values),
956 )
957 }
958
959 fn slice(
960 &self,
961 row_range: (usize, usize),
962 col_range: (usize, usize),
963 ) -> SparseResult<Box<dyn SparseArray<T>>> {
964 let (start_row, end_row) = row_range;
965 let (start_col, end_col) = col_range;
966 let (rows, cols) = self.shape();
967
968 if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
969 return Err(SparseError::IndexOutOfBounds {
970 index: (start_row.max(end_row), start_col.max(end_col)),
971 shape: (rows, cols),
972 });
973 }
974
975 if start_row >= end_row || start_col >= end_col {
976 return Err(SparseError::InvalidSliceRange);
977 }
978
979 let coo = self.to_coo()?;
981 coo.slice(row_range, col_range)?.to_bsr()
982 }
983
984 fn as_any(&self) -> &dyn std::any::Any {
985 self
986 }
987}
988
989impl<T> fmt::Display for BsrArray<T>
991where
992 T: Float
993 + Add<Output = T>
994 + Sub<Output = T>
995 + Mul<Output = T>
996 + Div<Output = T>
997 + Debug
998 + Copy
999 + 'static
1000 + std::ops::AddAssign,
1001{
1002 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1003 writeln!(
1004 f,
1005 "BsrArray of shape {:?} with {} stored elements",
1006 (self.rows, self.cols),
1007 self.nnz()
1008 )?;
1009 writeln!(f, "Block size: {:?}", self.block_size)?;
1010 writeln!(f, "Number of blocks: {}", self.data.len())?;
1011
1012 if self.data.len() <= 5 {
1013 for block_row in 0..self.block_rows {
1014 for k in self.indptr[block_row]..self.indptr[block_row + 1] {
1015 let block_col = self.indices[k][0];
1016 let block = &self.data[k];
1017
1018 writeln!(f, "Block at ({}, {}): ", block_row, block_col)?;
1019 for row in block {
1020 write!(f, " [")?;
1021 for (j, &val) in row.iter().enumerate() {
1022 if j > 0 {
1023 write!(f, ", ")?;
1024 }
1025 write!(f, "{:?}", val)?;
1026 }
1027 writeln!(f, "]")?;
1028 }
1029 }
1030 }
1031 } else {
1032 writeln!(f, "({} blocks total)", self.data.len())?;
1033 }
1034
1035 Ok(())
1036 }
1037}
1038
#[cfg(test)]
mod tests {
    use super::*;

    /// 4x4 matrix made of two 2x2 blocks on the block diagonal:
    ///
    /// ```text
    /// 1 2 0 0
    /// 3 4 0 0
    /// 0 0 5 6
    /// 0 0 7 8
    /// ```
    fn diagonal_blocks_4x4() -> BsrArray<f64> {
        let blocks = vec![
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![vec![5.0, 6.0], vec![7.0, 8.0]],
        ];
        BsrArray::new(blocks, vec![vec![0], vec![1]], vec![0, 1, 2], (4, 4), (2, 2)).unwrap()
    }

    /// 2x2 matrix stored as a single 2x2 block.
    fn single_block_2x2(block: Vec<Vec<f64>>) -> BsrArray<f64> {
        BsrArray::new(vec![block], vec![vec![0]], vec![0, 1], (2, 2), (2, 2)).unwrap()
    }

    /// Asserts that `array.get(row, col) == value` for every listed entry.
    fn assert_entries(array: &dyn SparseArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_eq!(array.get(row, col), value, "entry ({}, {})", row, col);
        }
    }

    /// The eight non-zero entries of `diagonal_blocks_4x4`, plus one implicit zero.
    const DIAGONAL_ENTRIES: [(usize, usize, f64); 9] = [
        (0, 0, 1.0),
        (0, 1, 2.0),
        (1, 0, 3.0),
        (1, 1, 4.0),
        (2, 2, 5.0),
        (2, 3, 6.0),
        (3, 2, 7.0),
        (3, 3, 8.0),
        (0, 2, 0.0),
    ];

    #[test]
    fn test_bsr_array_create() {
        let array = diagonal_blocks_4x4();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8);
        assert_entries(&array, &DIAGONAL_ENTRIES);
    }

    #[test]
    fn test_bsr_array_from_triplets() {
        let rows = [0, 0, 1, 1, 2, 2, 3, 3];
        let cols = [0, 1, 0, 1, 2, 3, 2, 3];
        let values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let array = BsrArray::from_triplets(&rows, &cols, &values, (4, 4), (2, 2)).unwrap();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8);
        assert_entries(&array, &DIAGONAL_ENTRIES);
    }

    #[test]
    fn test_bsr_array_conversion() {
        let array = diagonal_blocks_4x4();

        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (4, 4));
        assert_eq!(coo.nnz(), 8);

        let csr = array.to_csr().unwrap();
        assert_eq!(csr.shape(), (4, 4));
        assert_eq!(csr.nnz(), 8);

        let expected = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 5.0, 6.0, 0.0, 0.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        assert_eq!(array.to_array(), expected);
    }

    #[test]
    fn test_bsr_array_operations() {
        let lhs = single_block_2x2(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let rhs = single_block_2x2(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);

        let sum = lhs.add(&rhs).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_entries(&*sum, &[(0, 0, 6.0), (0, 1, 8.0), (1, 0, 10.0), (1, 1, 12.0)]);

        let product = lhs.mul(&rhs).unwrap();
        assert_eq!(product.shape(), (2, 2));
        assert_entries(
            &*product,
            &[(0, 0, 5.0), (0, 1, 12.0), (1, 0, 21.0), (1, 1, 32.0)],
        );

        let dot = lhs.dot(&rhs).unwrap();
        assert_eq!(dot.shape(), (2, 2));
        assert_entries(&*dot, &[(0, 0, 19.0), (0, 1, 22.0), (1, 0, 43.0), (1, 1, 50.0)]);
    }

    #[test]
    fn test_bsr_array_dot_vector() {
        let array = diagonal_blocks_4x4();
        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = array.dot_vector(&vector.view()).unwrap();

        assert_eq!(result, Array1::from_vec(vec![5.0, 11.0, 39.0, 53.0]));
    }

    #[test]
    fn test_bsr_array_sum() {
        let array = diagonal_blocks_4x4();

        match array.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 36.0),
            _ => panic!("Expected SparseSum::Scalar"),
        }

        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(row_sum) => {
                assert_eq!(row_sum.shape(), (1, 4));
                assert_entries(
                    &*row_sum,
                    &[(0, 0, 4.0), (0, 1, 6.0), (0, 2, 12.0), (0, 3, 14.0)],
                );
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }

        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(col_sum) => {
                assert_eq!(col_sum.shape(), (4, 1));
                assert_entries(
                    &*col_sum,
                    &[(0, 0, 3.0), (1, 0, 7.0), (2, 0, 11.0), (3, 0, 15.0)],
                );
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }
    }
}