scirs2_sparse/
bsr_array.rs

1// BSR Array implementation
2//
3// This module provides the BSR (Block Sparse Row) array format,
4// which is efficient for matrices with block-structured sparsity patterns.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csc_array::CscArray;
13use crate::csr_array::CsrArray;
14use crate::dia_array::DiaArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::{SparseArray, SparseSum};
19
20/// BSR Array format
21///
22/// The BSR (Block Sparse Row) format stores a sparse matrix as a sparse matrix
23/// of dense blocks. It's particularly efficient for matrices with block-structured
24/// sparsity patterns, such as those arising in finite element methods.
25///
26/// # Notes
27///
28/// - Very efficient for matrices with block structure
29/// - Fast matrix-vector products for block-structured matrices
30/// - Reduced indexing overhead compared to CSR for block-structured problems
31/// - Not efficient for general sparse matrices
32/// - Difficult to modify once constructed
33///
34#[derive(Clone)]
35pub struct BsrArray<T>
36where
37    T: Float
38        + Add<Output = T>
39        + Sub<Output = T>
40        + Mul<Output = T>
41        + Div<Output = T>
42        + Debug
43        + Copy
44        + 'static
45        + std::ops::AddAssign,
46{
47    /// Number of rows
48    rows: usize,
49    /// Number of columns
50    cols: usize,
51    /// Block size (r, c)
52    block_size: (usize, usize),
53    /// Number of block rows
54    block_rows: usize,
55    /// Number of block columns (needed for internal calculations)
56    _block_cols: usize,
57    /// Data array (blocks stored row by row)
58    data: Vec<Vec<Vec<T>>>,
59    /// Column indices for each block
60    indices: Vec<Vec<usize>>,
61    /// Row pointers (indptr)
62    indptr: Vec<usize>,
63}
64
65impl<T> BsrArray<T>
66where
67    T: Float
68        + Add<Output = T>
69        + Sub<Output = T>
70        + Mul<Output = T>
71        + Div<Output = T>
72        + Debug
73        + Copy
74        + 'static
75        + std::ops::AddAssign,
76{
77    /// Create a new BSR array from raw data
78    ///
79    /// # Arguments
80    ///
81    /// * `data` - Block data (blocks stored row by row)
82    /// * `indices` - Column indices for each block
83    /// * `indptr` - Row pointers
84    /// * `shape` - Tuple containing the array dimensions (rows, cols)
85    /// * `block_size` - Tuple containing the block dimensions (r, c)
86    ///
87    /// # Returns
88    ///
89    /// * A new BSR array
90    ///
91    /// # Examples
92    ///
93    /// ```
94    /// use scirs2_sparse::bsr_array::BsrArray;
95    /// use scirs2_sparse::sparray::SparseArray;
96    ///
97    /// // Create a 4x4 sparse array with 2x2 blocks
98    /// // [1 2 0 0]
99    /// // [3 4 0 0]
100    /// // [0 0 5 6]
101    /// // [0 0 7 8]
102    ///
103    /// let block1 = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
104    /// let block2 = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
105    ///
106    /// let data = vec![block1, block2];
107    /// let indices = vec![vec![0], vec![1]];
108    /// let indptr = vec![0, 1, 2];
109    ///
110    /// let array = BsrArray::new(data, indices, indptr, (4, 4), (2, 2)).unwrap();
111    /// assert_eq!(array.shape(), (4, 4));
112    /// assert_eq!(array.nnz(), 8); // All elements in the blocks are non-zero
113    /// ```
114    pub fn new(
115        data: Vec<Vec<Vec<T>>>,
116        indices: Vec<Vec<usize>>,
117        indptr: Vec<usize>,
118        shape: (usize, usize),
119        block_size: (usize, usize),
120    ) -> SparseResult<Self> {
121        let (rows, cols) = shape;
122        let (r, c) = block_size;
123
124        if r == 0 || c == 0 {
125            return Err(SparseError::ValueError(
126                "Block dimensions must be positive".to_string(),
127            ));
128        }
129
130        // Calculate block dimensions
131        #[allow(clippy::manual_div_ceil)]
132        let block_rows = (rows + r - 1) / r; // Ceiling division
133        #[allow(clippy::manual_div_ceil)]
134        let _block_cols = (cols + c - 1) / c; // Ceiling division
135
136        // Validate input
137        if indptr.len() != block_rows + 1 {
138            return Err(SparseError::DimensionMismatch {
139                expected: block_rows + 1,
140                found: indptr.len(),
141            });
142        }
143
144        if data.len() != indptr[block_rows] {
145            return Err(SparseError::DimensionMismatch {
146                expected: indptr[block_rows],
147                found: data.len(),
148            });
149        }
150
151        if indices.len() != data.len() {
152            return Err(SparseError::DimensionMismatch {
153                expected: data.len(),
154                found: indices.len(),
155            });
156        }
157
158        for block in data.iter() {
159            if block.len() != r {
160                return Err(SparseError::DimensionMismatch {
161                    expected: r,
162                    found: block.len(),
163                });
164            }
165
166            for row in block.iter() {
167                if row.len() != c {
168                    return Err(SparseError::DimensionMismatch {
169                        expected: c,
170                        found: row.len(),
171                    });
172                }
173            }
174        }
175
176        for idx_vec in indices.iter() {
177            if idx_vec.len() != 1 {
178                return Err(SparseError::ValueError(
179                    "Each index vector must contain exactly one block column index".to_string(),
180                ));
181            }
182            if idx_vec[0] >= _block_cols {
183                return Err(SparseError::ValueError(format!(
184                    "index {} out of bounds (max {})",
185                    idx_vec[0],
186                    _block_cols - 1
187                )));
188            }
189        }
190
191        Ok(BsrArray {
192            rows,
193            cols,
194            block_size,
195            block_rows,
196            _block_cols,
197            data,
198            indices,
199            indptr,
200        })
201    }
202
203    /// Create a new empty BSR array
204    ///
205    /// # Arguments
206    ///
207    /// * `shape` - Tuple containing the array dimensions (rows, cols)
208    /// * `block_size` - Tuple containing the block dimensions (r, c)
209    ///
210    /// # Returns
211    ///
212    /// * A new empty BSR array
213    pub fn empty(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
214        let (rows, cols) = shape;
215        let (r, c) = block_size;
216
217        if r == 0 || c == 0 {
218            return Err(SparseError::ValueError(
219                "Block dimensions must be positive".to_string(),
220            ));
221        }
222
223        // Calculate block dimensions
224        #[allow(clippy::manual_div_ceil)]
225        let block_rows = (rows + r - 1) / r; // Ceiling division
226        #[allow(clippy::manual_div_ceil)]
227        let _block_cols = (cols + c - 1) / c; // Ceiling division
228
229        // Initialize empty BSR array
230        let data = Vec::new();
231        let indices = Vec::new();
232        let indptr = vec![0; block_rows + 1];
233
234        Ok(BsrArray {
235            rows,
236            cols,
237            block_size,
238            block_rows,
239            _block_cols,
240            data,
241            indices,
242            indptr,
243        })
244    }
245
246    /// Convert triplets to BSR format
247    ///
248    /// # Arguments
249    ///
250    /// * `row` - Row indices
251    /// * `col` - Column indices
252    /// * `data` - Data values
253    /// * `shape` - Shape of the array
254    /// * `block_size` - Size of the blocks
255    ///
256    /// # Returns
257    ///
258    /// * A new BSR array
259    pub fn from_triplets(
260        row: &[usize],
261        col: &[usize],
262        data: &[T],
263        shape: (usize, usize),
264        block_size: (usize, usize),
265    ) -> SparseResult<Self> {
266        if row.len() != col.len() || row.len() != data.len() {
267            return Err(SparseError::InconsistentData {
268                reason: "Lengths of row, col, and data arrays must be equal".to_string(),
269            });
270        }
271
272        let (rows, cols) = shape;
273        let (r, c) = block_size;
274
275        if r == 0 || c == 0 {
276            return Err(SparseError::ValueError(
277                "Block dimensions must be positive".to_string(),
278            ));
279        }
280
281        // Calculate block dimensions
282        #[allow(clippy::manual_div_ceil)]
283        let block_rows = (rows + r - 1) / r; // Ceiling division
284        #[allow(clippy::manual_div_ceil)]
285        let _block_cols = (cols + c - 1) / c; // Ceiling division
286
287        // First, we'll construct a temporary DOK-like representation for the blocks
288        let mut block_data = std::collections::HashMap::new();
289
290        // Assign each element to its corresponding block
291        for (&row_idx, (&col_idx, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
292            if row_idx >= rows || col_idx >= cols {
293                return Err(SparseError::IndexOutOfBounds {
294                    index: (row_idx, col_idx),
295                    shape,
296                });
297            }
298
299            // Calculate block indices
300            let block_row = row_idx / r;
301            let block_col = col_idx / c;
302
303            // Calculate position within block
304            let block_row_pos = row_idx % r;
305            let block_col_pos = col_idx % c;
306
307            // Create or get the block
308            let block = block_data.entry((block_row, block_col)).or_insert_with(|| {
309                let block = vec![vec![T::zero(); c]; r];
310                block
311            });
312
313            // Set the value in the block
314            block[block_row_pos][block_col_pos] = val;
315        }
316
317        // Now convert the DOK-like format to BSR
318        let mut rows_with_blocks: Vec<usize> = block_data.keys().map(|&(row, _)| row).collect();
319        rows_with_blocks.sort();
320        rows_with_blocks.dedup();
321
322        // Create indptr array
323        let mut indptr = vec![0; block_rows + 1];
324        let mut current_nnz = 0;
325
326        // Sorted blocks data and indices
327        let mut data = Vec::new();
328        let mut indices = Vec::new();
329
330        for row_idx in 0..block_rows {
331            if rows_with_blocks.contains(&row_idx) {
332                // Get all blocks for this row
333                let mut row_blocks: Vec<(usize, Vec<Vec<T>>)> = block_data
334                    .iter()
335                    .filter(|&(&(r, _), _)| r == row_idx)
336                    .map(|(&(_, c), block)| (c, block.clone()))
337                    .collect();
338
339                // Sort by column index
340                row_blocks.sort_by_key(|&(col, _)| col);
341
342                // Add to data and indices
343                for (col, block) in row_blocks {
344                    data.push(block);
345                    indices.push(vec![col]);
346                    current_nnz += 1;
347                }
348            }
349
350            indptr[row_idx + 1] = current_nnz;
351        }
352
353        // Create the BSR array
354        BsrArray::new(data, indices, indptr, shape, block_size)
355    }
356
357    /// Convert to COO format triplets
358    fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
359        let (r, c) = self.block_size;
360        let mut row_indices = Vec::new();
361        let mut col_indices = Vec::new();
362        let mut values = Vec::new();
363
364        for block_row in 0..self.block_rows {
365            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
366                let block_col = self.indices[k][0];
367                let block = &self.data[k];
368
369                // For each element in the block
370                for (i, block_row_data) in block.iter().enumerate().take(r) {
371                    let row = block_row * r + i;
372                    if row < self.rows {
373                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
374                            let col = block_col * c + j;
375                            if col < self.cols && !value.is_zero() {
376                                row_indices.push(row);
377                                col_indices.push(col);
378                                values.push(value);
379                            }
380                        }
381                    }
382                }
383            }
384        }
385
386        (row_indices, col_indices, values)
387    }
388}
389
390impl<T> SparseArray<T> for BsrArray<T>
391where
392    T: Float
393        + Add<Output = T>
394        + Sub<Output = T>
395        + Mul<Output = T>
396        + Div<Output = T>
397        + Debug
398        + Copy
399        + 'static
400        + std::ops::AddAssign,
401{
402    fn shape(&self) -> (usize, usize) {
403        (self.rows, self.cols)
404    }
405
406    fn nnz(&self) -> usize {
407        let mut count = 0;
408
409        for block in &self.data {
410            for row in block {
411                for &val in row {
412                    if !val.is_zero() {
413                        count += 1;
414                    }
415                }
416            }
417        }
418
419        count
420    }
421
422    fn dtype(&self) -> &str {
423        "float" // Placeholder; ideally would return the actual type
424    }
425
426    fn to_array(&self) -> Array2<T> {
427        let mut result = Array2::zeros((self.rows, self.cols));
428        let (r, c) = self.block_size;
429
430        for block_row in 0..self.block_rows {
431            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
432                let block_col = self.indices[k][0];
433                let block = &self.data[k];
434
435                // Copy block to dense array
436                for (i, block_row_data) in block.iter().enumerate().take(r) {
437                    let row = block_row * r + i;
438                    if row < self.rows {
439                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
440                            let col = block_col * c + j;
441                            if col < self.cols {
442                                result[[row, col]] = value;
443                            }
444                        }
445                    }
446                }
447            }
448        }
449
450        result
451    }
452
453    fn toarray(&self) -> Array2<T> {
454        self.to_array()
455    }
456
457    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
458        let (row_indices, col_indices, values) = self.to_coo_internal();
459        CooArray::from_triplets(
460            &row_indices,
461            &col_indices,
462            &values,
463            (self.rows, self.cols),
464            false,
465        )
466        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
467    }
468
469    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
470        let (row_indices, col_indices, values) = self.to_coo_internal();
471        CsrArray::from_triplets(
472            &row_indices,
473            &col_indices,
474            &values,
475            (self.rows, self.cols),
476            false,
477        )
478        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
479    }
480
481    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
482        let (row_indices, col_indices, values) = self.to_coo_internal();
483        CscArray::from_triplets(
484            &row_indices,
485            &col_indices,
486            &values,
487            (self.rows, self.cols),
488            false,
489        )
490        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
491    }
492
493    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
494        let (row_indices, col_indices, values) = self.to_coo_internal();
495        DokArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
496            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
497    }
498
499    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
500        let (row_indices, col_indices, values) = self.to_coo_internal();
501        LilArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
502            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
503    }
504
505    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
506        let (row_indices, col_indices, values) = self.to_coo_internal();
507        DiaArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
508            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
509    }
510
511    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
512        Ok(Box::new(self.clone()))
513    }
514
515    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
516        // For efficiency, convert both to CSR for addition
517        let csr_self = self.to_csr()?;
518        let csr_other = other.to_csr()?;
519        csr_self.add(&*csr_other)
520    }
521
522    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
523        // For efficiency, convert both to CSR for subtraction
524        let csr_self = self.to_csr()?;
525        let csr_other = other.to_csr()?;
526        csr_self.sub(&*csr_other)
527    }
528
529    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
530        // For efficiency, convert both to CSR for element-wise multiplication
531        let csr_self = self.to_csr()?;
532        let csr_other = other.to_csr()?;
533        csr_self.mul(&*csr_other)
534    }
535
536    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
537        // For efficiency, convert both to CSR for element-wise division
538        let csr_self = self.to_csr()?;
539        let csr_other = other.to_csr()?;
540        csr_self.div(&*csr_other)
541    }
542
543    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
544        let (_, n) = self.shape();
545        let (p, q) = other.shape();
546
547        if n != p {
548            return Err(SparseError::DimensionMismatch {
549                expected: n,
550                found: p,
551            });
552        }
553
554        // If other is a vector (thin matrix), we can use optimized BSR-Vector multiplication
555        if q == 1 {
556            // Get the vector from other
557            let other_array = other.to_array();
558            let vec_view = other_array.column(0);
559
560            // Perform BSR-Vector multiplication
561            let result = self.dot_vector(&vec_view)?;
562
563            // Convert to a matrix - create a COO from triplets
564            let mut rows = Vec::new();
565            let mut cols = Vec::new();
566            let mut values = Vec::new();
567
568            for (i, &val) in result.iter().enumerate() {
569                if !val.is_zero() {
570                    rows.push(i);
571                    cols.push(0);
572                    values.push(val);
573                }
574            }
575
576            CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
577                .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
578        } else {
579            // For general matrix-matrix multiplication, convert to CSR
580            let csr_self = self.to_csr()?;
581            csr_self.dot(other)
582        }
583    }
584
585    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
586        let (rows, cols) = self.shape();
587        let (r, c) = self.block_size;
588
589        if cols != other.len() {
590            return Err(SparseError::DimensionMismatch {
591                expected: cols,
592                found: other.len(),
593            });
594        }
595
596        let mut result = Array1::zeros(rows);
597
598        for block_row in 0..self.block_rows {
599            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
600                let block_col = self.indices[k][0];
601                let block = &self.data[k];
602
603                // For each element in the block
604                for (i, block_row_data) in block.iter().enumerate().take(r) {
605                    let row = block_row * r + i;
606                    if row < self.rows {
607                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
608                            let col = block_col * c + j;
609                            if col < self.cols {
610                                result[row] += value * other[col];
611                            }
612                        }
613                    }
614                }
615            }
616        }
617
618        Ok(result)
619    }
620
621    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
622        // For efficiency, convert to COO, transpose, then convert back to BSR
623        self.to_coo()?.transpose()?.to_bsr()
624    }
625
626    fn copy(&self) -> Box<dyn SparseArray<T>> {
627        Box::new(self.clone())
628    }
629
630    fn get(&self, i: usize, j: usize) -> T {
631        if i >= self.rows || j >= self.cols {
632            return T::zero();
633        }
634
635        let (r, c) = self.block_size;
636        let block_row = i / r;
637        let block_col = j / c;
638        let block_row_pos = i % r;
639        let block_col_pos = j % c;
640
641        // Search for the block in the row
642        for k in self.indptr[block_row]..self.indptr[block_row + 1] {
643            if self.indices[k][0] == block_col {
644                return self.data[k][block_row_pos][block_col_pos];
645            }
646        }
647
648        T::zero()
649    }
650
651    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
652        if i >= self.rows || j >= self.cols {
653            return Err(SparseError::IndexOutOfBounds {
654                index: (i, j),
655                shape: (self.rows, self.cols),
656            });
657        }
658
659        let (r, c) = self.block_size;
660        let block_row = i / r;
661        let block_col = j / c;
662        let block_row_pos = i % r;
663        let block_col_pos = j % c;
664
665        // Search for the block in the row
666        for k in self.indptr[block_row]..self.indptr[block_row + 1] {
667            if self.indices[k][0] == block_col {
668                // Block exists, update value
669                self.data[k][block_row_pos][block_col_pos] = value;
670                return Ok(());
671            }
672        }
673
674        // Block doesn't exist, we need to create it
675        if !value.is_zero() {
676            // Find position to insert
677            let pos = self.indptr[block_row + 1];
678
679            // Create new block
680            let mut block = vec![vec![T::zero(); c]; r];
681            block[block_row_pos][block_col_pos] = value;
682
683            // Insert block, indices
684            self.data.insert(pos, block);
685            self.indices.insert(pos, vec![block_col]);
686
687            // Update indptr for subsequent rows
688            for k in (block_row + 1)..=self.block_rows {
689                self.indptr[k] += 1;
690            }
691
692            Ok(())
693        } else {
694            // If value is zero and block doesn't exist, do nothing
695            Ok(())
696        }
697    }
698
699    fn eliminate_zeros(&mut self) {
700        // No need to use block_size variables here
701        let mut new_data = Vec::new();
702        let mut new_indices = Vec::new();
703        let mut new_indptr = vec![0];
704        let mut current_nnz = 0;
705
706        for block_row in 0..self.block_rows {
707            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
708                let block_col = self.indices[k][0];
709                let block = &self.data[k];
710
711                // Check if block has any non-zero elements
712                let mut has_nonzero = false;
713                for row in block {
714                    for &val in row {
715                        if !val.is_zero() {
716                            has_nonzero = true;
717                            break;
718                        }
719                    }
720                    if has_nonzero {
721                        break;
722                    }
723                }
724
725                if has_nonzero {
726                    new_data.push(block.clone());
727                    new_indices.push(vec![block_col]);
728                    current_nnz += 1;
729                }
730            }
731
732            new_indptr.push(current_nnz);
733        }
734
735        self.data = new_data;
736        self.indices = new_indices;
737        self.indptr = new_indptr;
738    }
739
740    fn sort_indices(&mut self) {
741        // No need to use block_size variables here
742        let mut new_data = Vec::new();
743        let mut new_indices = Vec::new();
744        let mut new_indptr = vec![0];
745        let mut current_nnz = 0;
746
747        for block_row in 0..self.block_rows {
748            // Get blocks for this row
749            let mut row_blocks = Vec::new();
750            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
751                row_blocks.push((self.indices[k][0], self.data[k].clone()));
752            }
753
754            // Sort by column index
755            row_blocks.sort_by_key(|&(col, _)| col);
756
757            // Add sorted blocks to new data structures
758            for (col, block) in row_blocks {
759                new_data.push(block);
760                new_indices.push(vec![col]);
761                current_nnz += 1;
762            }
763
764            new_indptr.push(current_nnz);
765        }
766
767        self.data = new_data;
768        self.indices = new_indices;
769        self.indptr = new_indptr;
770    }
771
772    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
773        let mut result = self.clone();
774        result.sort_indices();
775        Box::new(result)
776    }
777
778    fn has_sorted_indices(&self) -> bool {
779        for block_row in 0..self.block_rows {
780            let mut prev_col = None;
781
782            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
783                let col = self.indices[k][0];
784
785                if let Some(prev) = prev_col {
786                    if col <= prev {
787                        return false;
788                    }
789                }
790
791                prev_col = Some(col);
792            }
793        }
794
795        true
796    }
797
798    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
799        match axis {
800            None => {
801                // Sum all elements
802                let mut total = T::zero();
803
804                for block in &self.data {
805                    for row in block {
806                        for &val in row {
807                            total += val;
808                        }
809                    }
810                }
811
812                Ok(SparseSum::Scalar(total))
813            }
814            Some(0) => {
815                // Sum along rows (result is 1 x cols)
816                let mut result = vec![T::zero(); self.cols];
817                let (r, c) = self.block_size;
818
819                for block_row in 0..self.block_rows {
820                    for k in self.indptr[block_row]..self.indptr[block_row + 1] {
821                        let block_col = self.indices[k][0];
822                        let block = &self.data[k];
823
824                        for block_row_data in block.iter().take(r) {
825                            for (j, &value) in block_row_data.iter().enumerate().take(c) {
826                                let col = block_col * c + j;
827                                if col < self.cols {
828                                    result[col] += value;
829                                }
830                            }
831                        }
832                    }
833                }
834
835                // Create a sparse array from the result
836                let mut row_indices = Vec::new();
837                let mut col_indices = Vec::new();
838                let mut values = Vec::new();
839
840                for (j, &val) in result.iter().enumerate() {
841                    if !val.is_zero() {
842                        row_indices.push(0);
843                        col_indices.push(j);
844                        values.push(val);
845                    }
846                }
847
848                match CooArray::from_triplets(
849                    &row_indices,
850                    &col_indices,
851                    &values,
852                    (1, self.cols),
853                    false,
854                ) {
855                    Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
856                    Err(e) => Err(e),
857                }
858            }
859            Some(1) => {
860                // Sum along columns (result is rows x 1)
861                let mut result = vec![T::zero(); self.rows];
862                let (r, c) = self.block_size;
863
864                for block_row in 0..self.block_rows {
865                    for k in self.indptr[block_row]..self.indptr[block_row + 1] {
866                        let block = &self.data[k];
867
868                        for (i, block_row_data) in block.iter().enumerate().take(r) {
869                            let row = block_row * r + i;
870                            if row < self.rows {
871                                for &value in block_row_data.iter().take(c) {
872                                    result[row] += value;
873                                }
874                            }
875                        }
876                    }
877                }
878
879                // Create a sparse array from the result
880                let mut row_indices = Vec::new();
881                let mut col_indices = Vec::new();
882                let mut values = Vec::new();
883
884                for (i, &val) in result.iter().enumerate() {
885                    if !val.is_zero() {
886                        row_indices.push(i);
887                        col_indices.push(0);
888                        values.push(val);
889                    }
890                }
891
892                match CooArray::from_triplets(
893                    &row_indices,
894                    &col_indices,
895                    &values,
896                    (self.rows, 1),
897                    false,
898                ) {
899                    Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
900                    Err(e) => Err(e),
901                }
902            }
903            _ => Err(SparseError::InvalidAxis),
904        }
905    }
906
907    fn max(&self) -> T {
908        let mut max_val = T::neg_infinity();
909
910        for block in &self.data {
911            for row in block {
912                for &val in row {
913                    max_val = max_val.max(val);
914                }
915            }
916        }
917
918        // If no elements or all negative infinity, return zero
919        if max_val == T::neg_infinity() {
920            T::zero()
921        } else {
922            max_val
923        }
924    }
925
926    fn min(&self) -> T {
927        let mut min_val = T::infinity();
928        let mut has_nonzero = false;
929
930        for block in &self.data {
931            for row in block {
932                for &val in row {
933                    if !val.is_zero() {
934                        has_nonzero = true;
935                        min_val = min_val.min(val);
936                    }
937                }
938            }
939        }
940
941        // If no non-zero elements, return zero
942        if !has_nonzero {
943            T::zero()
944        } else {
945            min_val
946        }
947    }
948
949    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
950        let (row_indices, col_indices, values) = self.to_coo_internal();
951
952        (
953            Array1::from_vec(row_indices),
954            Array1::from_vec(col_indices),
955            Array1::from_vec(values),
956        )
957    }
958
959    fn slice(
960        &self,
961        row_range: (usize, usize),
962        col_range: (usize, usize),
963    ) -> SparseResult<Box<dyn SparseArray<T>>> {
964        let (start_row, end_row) = row_range;
965        let (start_col, end_col) = col_range;
966        let (rows, cols) = self.shape();
967
968        if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
969            return Err(SparseError::IndexOutOfBounds {
970                index: (start_row.max(end_row), start_col.max(end_col)),
971                shape: (rows, cols),
972            });
973        }
974
975        if start_row >= end_row || start_col >= end_col {
976            return Err(SparseError::InvalidSliceRange);
977        }
978
979        // Convert to COO, slice, then convert back to BSR
980        let coo = self.to_coo()?;
981        coo.slice(row_range, col_range)?.to_bsr()
982    }
983
984    fn as_any(&self) -> &dyn std::any::Any {
985        self
986    }
987}
988
989// Implement Display for BsrArray for better debugging
990impl<T> fmt::Display for BsrArray<T>
991where
992    T: Float
993        + Add<Output = T>
994        + Sub<Output = T>
995        + Mul<Output = T>
996        + Div<Output = T>
997        + Debug
998        + Copy
999        + 'static
1000        + std::ops::AddAssign,
1001{
1002    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1003        writeln!(
1004            f,
1005            "BsrArray of shape {:?} with {} stored elements",
1006            (self.rows, self.cols),
1007            self.nnz()
1008        )?;
1009        writeln!(f, "Block size: {:?}", self.block_size)?;
1010        writeln!(f, "Number of blocks: {}", self.data.len())?;
1011
1012        if self.data.len() <= 5 {
1013            for block_row in 0..self.block_rows {
1014                for k in self.indptr[block_row]..self.indptr[block_row + 1] {
1015                    let block_col = self.indices[k][0];
1016                    let block = &self.data[k];
1017
1018                    writeln!(f, "Block at ({}, {}): ", block_row, block_col)?;
1019                    for row in block {
1020                        write!(f, "  [")?;
1021                        for (j, &val) in row.iter().enumerate() {
1022                            if j > 0 {
1023                                write!(f, ", ")?;
1024                            }
1025                            write!(f, "{:?}", val)?;
1026                        }
1027                        writeln!(f, "]")?;
1028                    }
1029                }
1030            }
1031        } else {
1032            writeln!(f, "({} blocks total)", self.data.len())?;
1033        }
1034
1035        Ok(())
1036    }
1037}
1038
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected `(row, col, value)` entries of the `block_diagonal_4x4`
    /// fixture, including one implicit zero outside the stored blocks.
    const BLOCK_DIAGONAL_ENTRIES: [(usize, usize, f64); 9] = [
        (0, 0, 1.0),
        (0, 1, 2.0),
        (1, 0, 3.0),
        (1, 1, 4.0),
        (2, 2, 5.0),
        (2, 3, 6.0),
        (3, 2, 7.0),
        (3, 3, 8.0),
        (0, 2, 0.0),
    ];

    /// Builds the 4x4 array shared by most tests, stored as two 2x2 blocks on
    /// the block diagonal:
    ///
    /// ```text
    /// [1 2 0 0]
    /// [3 4 0 0]
    /// [0 0 5 6]
    /// [0 0 7 8]
    /// ```
    fn block_diagonal_4x4() -> BsrArray<f64> {
        let upper_left = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let lower_right = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
        BsrArray::new(
            vec![upper_left, lower_right],
            vec![vec![0], vec![1]],
            vec![0, 1, 2],
            (4, 4),
            (2, 2),
        )
        .unwrap()
    }

    /// Builds a 2x2 array made of a single dense 2x2 block.
    fn single_block_2x2(block: Vec<Vec<f64>>) -> BsrArray<f64> {
        BsrArray::new(vec![block], vec![vec![0]], vec![0, 1], (2, 2), (2, 2)).unwrap()
    }

    /// Checks each expected `(row, col, value)` triple against `array.get`.
    fn assert_entries(array: &BsrArray<f64>, expected: &[(usize, usize, f64)]) {
        for &(row, col, value) in expected {
            assert_eq!(array.get(row, col), value, "entry ({}, {})", row, col);
        }
    }

    #[test]
    fn test_bsr_array_create() {
        let array = block_diagonal_4x4();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        // Every element inside the two stored blocks is non-zero.
        assert_eq!(array.nnz(), 8);
        assert_entries(&array, &BLOCK_DIAGONAL_ENTRIES);
    }

    #[test]
    fn test_bsr_array_from_triplets() {
        // Same matrix as `block_diagonal_4x4`, assembled from coordinates.
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let cols = vec![0, 1, 0, 1, 2, 3, 2, 3];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let array = BsrArray::from_triplets(&rows, &cols, &values, (4, 4), (2, 2)).unwrap();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8);
        assert_entries(&array, &BLOCK_DIAGONAL_ENTRIES);
    }

    #[test]
    fn test_bsr_array_conversion() {
        let array = block_diagonal_4x4();

        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (4, 4));
        assert_eq!(coo.nnz(), 8);

        let csr = array.to_csr().unwrap();
        assert_eq!(csr.shape(), (4, 4));
        assert_eq!(csr.nnz(), 8);

        #[rustfmt::skip]
        let expected = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 2.0, 0.0, 0.0,
                3.0, 4.0, 0.0, 0.0,
                0.0, 0.0, 5.0, 6.0,
                0.0, 0.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        assert_eq!(array.to_array(), expected);
    }

    #[test]
    fn test_bsr_array_operations() {
        let array1 = single_block_2x2(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let array2 = single_block_2x2(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);

        // Element-wise addition: 1+5, 2+6, 3+7, 4+8.
        let sum = array1.add(&array2).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        for &(i, j, expected) in &[(0, 0, 6.0), (0, 1, 8.0), (1, 0, 10.0), (1, 1, 12.0)] {
            assert_eq!(sum.get(i, j), expected);
        }

        // Element-wise multiplication: 1*5, 2*6, 3*7, 4*8.
        let product = array1.mul(&array2).unwrap();
        assert_eq!(product.shape(), (2, 2));
        for &(i, j, expected) in &[(0, 0, 5.0), (0, 1, 12.0), (1, 0, 21.0), (1, 1, 32.0)] {
            assert_eq!(product.get(i, j), expected);
        }

        // Matrix product: [1*5 + 2*7, 1*6 + 2*8; 3*5 + 4*7, 3*6 + 4*8].
        let dot = array1.dot(&array2).unwrap();
        assert_eq!(dot.shape(), (2, 2));
        for &(i, j, expected) in &[(0, 0, 19.0), (0, 1, 22.0), (1, 0, 43.0), (1, 1, 50.0)] {
            assert_eq!(dot.get(i, j), expected);
        }
    }

    #[test]
    fn test_bsr_array_dot_vector() {
        let array = block_diagonal_4x4();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let y = array.dot_vector(&x.view()).unwrap();

        // [1*1 + 2*2, 3*1 + 4*2, 5*3 + 6*4, 7*3 + 8*4] = [5, 11, 39, 53]
        assert_eq!(y, Array1::from_vec(vec![5.0, 11.0, 39.0, 53.0]));
    }

    #[test]
    fn test_bsr_array_sum() {
        let array = block_diagonal_4x4();

        // Total of all entries: 1 + 2 + ... + 8 = 36.
        match array.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 36.0),
            _ => panic!("Expected SparseSum::Scalar"),
        }

        // Axis 0 collapses the rows, leaving one total per column (shape 1 x 4):
        // 1+3, 2+4, 5+7, 6+8.
        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(column_totals) => {
                assert_eq!(column_totals.shape(), (1, 4));
                for &(j, expected) in &[(0, 4.0), (1, 6.0), (2, 12.0), (3, 14.0)] {
                    assert_eq!(column_totals.get(0, j), expected);
                }
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }

        // Axis 1 collapses the columns, leaving one total per row (shape 4 x 1):
        // 1+2, 3+4, 5+6, 7+8.
        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(row_totals) => {
                assert_eq!(row_totals.shape(), (4, 1));
                for &(i, expected) in &[(0, 3.0), (1, 7.0), (2, 11.0), (3, 15.0)] {
                    assert_eq!(row_totals.get(i, 0), expected);
                }
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }
    }
}