scirs2_sparse/
bsr_array.rs

1// BSR Array implementation
2//
3// This module provides the BSR (Block Sparse Row) array format,
4// which is efficient for matrices with block-structured sparsity patterns.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use num_traits::Float;
8use std::fmt::{self, Debug};
9use std::ops::{Add, Div, Mul, Sub};
10
11use crate::coo_array::CooArray;
12use crate::csc_array::CscArray;
13use crate::csr_array::CsrArray;
14use crate::dia_array::DiaArray;
15use crate::dok_array::DokArray;
16use crate::error::{SparseError, SparseResult};
17use crate::lil_array::LilArray;
18use crate::sparray::{SparseArray, SparseSum};
19
20/// BSR Array format
21///
22/// The BSR (Block Sparse Row) format stores a sparse matrix as a sparse matrix
23/// of dense blocks. It's particularly efficient for matrices with block-structured
24/// sparsity patterns, such as those arising in finite element methods.
25///
26/// # Notes
27///
28/// - Very efficient for matrices with block structure
29/// - Fast matrix-vector products for block-structured matrices
30/// - Reduced indexing overhead compared to CSR for block-structured problems
31/// - Not efficient for general sparse matrices
32/// - Difficult to modify once constructed
33///
34#[derive(Clone)]
35pub struct BsrArray<T>
36where
37    T: Float
38        + Add<Output = T>
39        + Sub<Output = T>
40        + Mul<Output = T>
41        + Div<Output = T>
42        + Debug
43        + Copy
44        + 'static
45        + std::ops::AddAssign,
46{
47    /// Number of rows
48    rows: usize,
49    /// Number of columns
50    cols: usize,
51    /// Block size (r, c)
52    block_size: (usize, usize),
53    /// Number of block rows
54    block_rows: usize,
55    /// Number of block columns (needed for internal calculations)
56    #[allow(dead_code)]
57    block_cols: usize,
58    /// Data array (blocks stored row by row)
59    data: Vec<Vec<Vec<T>>>,
60    /// Column indices for each block
61    indices: Vec<Vec<usize>>,
62    /// Row pointers (indptr)
63    indptr: Vec<usize>,
64}
65
66impl<T> BsrArray<T>
67where
68    T: Float
69        + Add<Output = T>
70        + Sub<Output = T>
71        + Mul<Output = T>
72        + Div<Output = T>
73        + Debug
74        + Copy
75        + 'static
76        + std::ops::AddAssign,
77{
78    /// Create a new BSR array from raw data
79    ///
80    /// # Arguments
81    ///
82    /// * `data` - Block data (blocks stored row by row)
83    /// * `indices` - Column indices for each block
84    /// * `indptr` - Row pointers
85    /// * `shape` - Tuple containing the array dimensions (rows, cols)
86    /// * `block_size` - Tuple containing the block dimensions (r, c)
87    ///
88    /// # Returns
89    ///
90    /// * A new BSR array
91    ///
92    /// # Examples
93    ///
94    /// ```
95    /// use scirs2_sparse::bsr_array::BsrArray;
96    /// use scirs2_sparse::sparray::SparseArray;
97    ///
98    /// // Create a 4x4 sparse array with 2x2 blocks
99    /// // [1 2 0 0]
100    /// // [3 4 0 0]
101    /// // [0 0 5 6]
102    /// // [0 0 7 8]
103    ///
104    /// let block1 = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
105    /// let block2 = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
106    ///
107    /// let data = vec![block1, block2];
108    /// let indices = vec![vec![0], vec![1]];
109    /// let indptr = vec![0, 1, 2];
110    ///
111    /// let array = BsrArray::new(data, indices, indptr, (4, 4), (2, 2)).unwrap();
112    /// assert_eq!(array.shape(), (4, 4));
113    /// assert_eq!(array.nnz(), 8); // All elements in the blocks are non-zero
114    /// ```
115    pub fn new(
116        data: Vec<Vec<Vec<T>>>,
117        indices: Vec<Vec<usize>>,
118        indptr: Vec<usize>,
119        shape: (usize, usize),
120        block_size: (usize, usize),
121    ) -> SparseResult<Self> {
122        let (rows, cols) = shape;
123        let (r, c) = block_size;
124
125        if r == 0 || c == 0 {
126            return Err(SparseError::ValueError(
127                "Block dimensions must be positive".to_string(),
128            ));
129        }
130
131        // Calculate block dimensions
132        #[allow(clippy::manual_div_ceil)]
133        let block_rows = (rows + r - 1) / r; // Ceiling division
134        #[allow(clippy::manual_div_ceil)]
135        let block_cols = (cols + c - 1) / c; // Ceiling division
136
137        // Validate input
138        if indptr.len() != block_rows + 1 {
139            return Err(SparseError::DimensionMismatch {
140                expected: block_rows + 1,
141                found: indptr.len(),
142            });
143        }
144
145        if data.len() != indptr[block_rows] {
146            return Err(SparseError::DimensionMismatch {
147                expected: indptr[block_rows],
148                found: data.len(),
149            });
150        }
151
152        if indices.len() != data.len() {
153            return Err(SparseError::DimensionMismatch {
154                expected: data.len(),
155                found: indices.len(),
156            });
157        }
158
159        for block in data.iter() {
160            if block.len() != r {
161                return Err(SparseError::DimensionMismatch {
162                    expected: r,
163                    found: block.len(),
164                });
165            }
166
167            for row in block.iter() {
168                if row.len() != c {
169                    return Err(SparseError::DimensionMismatch {
170                        expected: c,
171                        found: row.len(),
172                    });
173                }
174            }
175        }
176
177        for idx_vec in indices.iter() {
178            if idx_vec.len() != 1 {
179                return Err(SparseError::ValueError(
180                    "Each index vector must contain exactly one block column index".to_string(),
181                ));
182            }
183            if idx_vec[0] >= block_cols {
184                return Err(SparseError::ValueError(format!(
185                    "index {} out of bounds (max {})",
186                    idx_vec[0],
187                    block_cols - 1
188                )));
189            }
190        }
191
192        Ok(BsrArray {
193            rows,
194            cols,
195            block_size,
196            block_rows,
197            block_cols,
198            data,
199            indices,
200            indptr,
201        })
202    }
203
204    /// Create a new empty BSR array
205    ///
206    /// # Arguments
207    ///
208    /// * `shape` - Tuple containing the array dimensions (rows, cols)
209    /// * `block_size` - Tuple containing the block dimensions (r, c)
210    ///
211    /// # Returns
212    ///
213    /// * A new empty BSR array
214    pub fn empty(shape: (usize, usize), block_size: (usize, usize)) -> SparseResult<Self> {
215        let (rows, cols) = shape;
216        let (r, c) = block_size;
217
218        if r == 0 || c == 0 {
219            return Err(SparseError::ValueError(
220                "Block dimensions must be positive".to_string(),
221            ));
222        }
223
224        // Calculate block dimensions
225        #[allow(clippy::manual_div_ceil)]
226        let block_rows = (rows + r - 1) / r; // Ceiling division
227        #[allow(clippy::manual_div_ceil)]
228        let block_cols = (cols + c - 1) / c; // Ceiling division
229
230        // Initialize empty BSR array
231        let data = Vec::new();
232        let indices = Vec::new();
233        let indptr = vec![0; block_rows + 1];
234
235        Ok(BsrArray {
236            rows,
237            cols,
238            block_size,
239            block_rows,
240            block_cols,
241            data,
242            indices,
243            indptr,
244        })
245    }
246
247    /// Convert triplets to BSR format
248    ///
249    /// # Arguments
250    ///
251    /// * `row` - Row indices
252    /// * `col` - Column indices
253    /// * `data` - Data values
254    /// * `shape` - Shape of the array
255    /// * `block_size` - Size of the blocks
256    ///
257    /// # Returns
258    ///
259    /// * A new BSR array
260    pub fn from_triplets(
261        row: &[usize],
262        col: &[usize],
263        data: &[T],
264        shape: (usize, usize),
265        block_size: (usize, usize),
266    ) -> SparseResult<Self> {
267        if row.len() != col.len() || row.len() != data.len() {
268            return Err(SparseError::InconsistentData {
269                reason: "Lengths of row, col, and data arrays must be equal".to_string(),
270            });
271        }
272
273        let (rows, cols) = shape;
274        let (r, c) = block_size;
275
276        if r == 0 || c == 0 {
277            return Err(SparseError::ValueError(
278                "Block dimensions must be positive".to_string(),
279            ));
280        }
281
282        // Calculate block dimensions
283        #[allow(clippy::manual_div_ceil)]
284        let block_rows = (rows + r - 1) / r; // Ceiling division
285        #[allow(clippy::manual_div_ceil)]
286        let block_cols = (cols + c - 1) / c; // Ceiling division
287
288        // First, we'll construct a temporary DOK-like representation for the blocks
289        let mut block_data = std::collections::HashMap::new();
290
291        // Assign each element to its corresponding block
292        for (&row_idx, (&col_idx, &val)) in row.iter().zip(col.iter().zip(data.iter())) {
293            if row_idx >= rows || col_idx >= cols {
294                return Err(SparseError::IndexOutOfBounds {
295                    index: (row_idx, col_idx),
296                    shape,
297                });
298            }
299
300            // Calculate block indices
301            let block_row = row_idx / r;
302            let block_col = col_idx / c;
303
304            // Calculate position within block
305            let block_row_pos = row_idx % r;
306            let block_col_pos = col_idx % c;
307
308            // Create or get the block
309            let block = block_data.entry((block_row, block_col)).or_insert_with(|| {
310                let block = vec![vec![T::zero(); c]; r];
311                block
312            });
313
314            // Set the value in the block
315            block[block_row_pos][block_col_pos] = val;
316        }
317
318        // Now convert the DOK-like format to BSR
319        let mut rowswith_blocks: Vec<usize> = block_data.keys().map(|&(row_, _)| row_).collect();
320        rowswith_blocks.sort();
321        rowswith_blocks.dedup();
322
323        // Create indptr array
324        let mut indptr = vec![0; block_rows + 1];
325        let mut current_nnz = 0;
326
327        // Sorted blocks data and indices
328        let mut data = Vec::new();
329        let mut indices = Vec::new();
330
331        for row_idx in 0..block_rows {
332            if rowswith_blocks.contains(&row_idx) {
333                // Get all blocks for this row
334                let mut row_blocks: Vec<(usize, Vec<Vec<T>>)> = block_data
335                    .iter()
336                    .filter(|&(&(r, _), _)| r == row_idx)
337                    .map(|(&(_, c), block)| (c, block.clone()))
338                    .collect();
339
340                // Sort by column index
341                row_blocks.sort_by_key(|&(col_, _)| col_);
342
343                // Add to data and indices
344                for (col, block) in row_blocks {
345                    data.push(block);
346                    indices.push(vec![col]);
347                    current_nnz += 1;
348                }
349            }
350
351            indptr[row_idx + 1] = current_nnz;
352        }
353
354        // Create the BSR array
355        BsrArray::new(data, indices, indptr, shape, block_size)
356    }
357
358    /// Convert to COO format triplets
359    fn to_coo_internal(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
360        let (r, c) = self.block_size;
361        let mut row_indices = Vec::new();
362        let mut col_indices = Vec::new();
363        let mut values = Vec::new();
364
365        for block_row in 0..self.block_rows {
366            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
367                let block_col = self.indices[k][0];
368                let block = &self.data[k];
369
370                // For each element in the block
371                for (i, block_row_data) in block.iter().enumerate().take(r) {
372                    let row = block_row * r + i;
373                    if row < self.rows {
374                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
375                            let col = block_col * c + j;
376                            if col < self.cols && !value.is_zero() {
377                                row_indices.push(row);
378                                col_indices.push(col);
379                                values.push(value);
380                            }
381                        }
382                    }
383                }
384            }
385        }
386
387        (row_indices, col_indices, values)
388    }
389}
390
391impl<T> SparseArray<T> for BsrArray<T>
392where
393    T: Float
394        + Add<Output = T>
395        + Sub<Output = T>
396        + Mul<Output = T>
397        + Div<Output = T>
398        + Debug
399        + Copy
400        + 'static
401        + std::ops::AddAssign,
402{
403    fn shape(&self) -> (usize, usize) {
404        (self.rows, self.cols)
405    }
406
407    fn nnz(&self) -> usize {
408        let mut count = 0;
409
410        for block in &self.data {
411            for row in block {
412                for &val in row {
413                    if !val.is_zero() {
414                        count += 1;
415                    }
416                }
417            }
418        }
419
420        count
421    }
422
423    fn dtype(&self) -> &str {
424        "float" // Placeholder; ideally would return the actual type
425    }
426
427    fn to_array(&self) -> Array2<T> {
428        let mut result = Array2::zeros((self.rows, self.cols));
429        let (r, c) = self.block_size;
430
431        for block_row in 0..self.block_rows {
432            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
433                let block_col = self.indices[k][0];
434                let block = &self.data[k];
435
436                // Copy block to dense array
437                for (i, block_row_data) in block.iter().enumerate().take(r) {
438                    let row = block_row * r + i;
439                    if row < self.rows {
440                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
441                            let col = block_col * c + j;
442                            if col < self.cols {
443                                result[[row, col]] = value;
444                            }
445                        }
446                    }
447                }
448            }
449        }
450
451        result
452    }
453
454    fn toarray(&self) -> Array2<T> {
455        self.to_array()
456    }
457
458    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
459        let (row_indices, col_indices, values) = self.to_coo_internal();
460        CooArray::from_triplets(
461            &row_indices,
462            &col_indices,
463            &values,
464            (self.rows, self.cols),
465            false,
466        )
467        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
468    }
469
470    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
471        let (row_indices, col_indices, values) = self.to_coo_internal();
472        CsrArray::from_triplets(
473            &row_indices,
474            &col_indices,
475            &values,
476            (self.rows, self.cols),
477            false,
478        )
479        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
480    }
481
482    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
483        let (row_indices, col_indices, values) = self.to_coo_internal();
484        CscArray::from_triplets(
485            &row_indices,
486            &col_indices,
487            &values,
488            (self.rows, self.cols),
489            false,
490        )
491        .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
492    }
493
494    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
495        let (row_indices, col_indices, values) = self.to_coo_internal();
496        DokArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
497            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
498    }
499
500    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
501        let (row_indices, col_indices, values) = self.to_coo_internal();
502        LilArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
503            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
504    }
505
506    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
507        let (row_indices, col_indices, values) = self.to_coo_internal();
508        DiaArray::from_triplets(&row_indices, &col_indices, &values, (self.rows, self.cols))
509            .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
510    }
511
512    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
513        Ok(Box::new(self.clone()))
514    }
515
516    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
517        // For efficiency, convert both to CSR for addition
518        let csr_self = self.to_csr()?;
519        let csr_other = other.to_csr()?;
520        csr_self.add(&*csr_other)
521    }
522
523    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
524        // For efficiency, convert both to CSR for subtraction
525        let csr_self = self.to_csr()?;
526        let csr_other = other.to_csr()?;
527        csr_self.sub(&*csr_other)
528    }
529
530    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
531        // For efficiency, convert both to CSR for element-wise multiplication
532        let csr_self = self.to_csr()?;
533        let csr_other = other.to_csr()?;
534        csr_self.mul(&*csr_other)
535    }
536
537    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
538        // For efficiency, convert both to CSR for element-wise division
539        let csr_self = self.to_csr()?;
540        let csr_other = other.to_csr()?;
541        csr_self.div(&*csr_other)
542    }
543
544    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
545        let (_, n) = self.shape();
546        let (p, q) = other.shape();
547
548        if n != p {
549            return Err(SparseError::DimensionMismatch {
550                expected: n,
551                found: p,
552            });
553        }
554
555        // If other is a vector (thin matrix), we can use optimized BSR-Vector multiplication
556        if q == 1 {
557            // Get the vector from other
558            let other_array = other.to_array();
559            let vec_view = other_array.column(0);
560
561            // Perform BSR-Vector multiplication
562            let result = self.dot_vector(&vec_view)?;
563
564            // Convert to a matrix - create a COO from triplets
565            let mut rows = Vec::new();
566            let mut cols = Vec::new();
567            let mut values = Vec::new();
568
569            for (i, &val) in result.iter().enumerate() {
570                if !val.is_zero() {
571                    rows.push(i);
572                    cols.push(0);
573                    values.push(val);
574                }
575            }
576
577            CooArray::from_triplets(&rows, &cols, &values, (result.len(), 1), false)
578                .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
579        } else {
580            // For general matrix-matrix multiplication, convert to CSR
581            let csr_self = self.to_csr()?;
582            csr_self.dot(other)
583        }
584    }
585
586    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
587        let (rows, cols) = self.shape();
588        let (r, c) = self.block_size;
589
590        if cols != other.len() {
591            return Err(SparseError::DimensionMismatch {
592                expected: cols,
593                found: other.len(),
594            });
595        }
596
597        let mut result = Array1::zeros(rows);
598
599        for block_row in 0..self.block_rows {
600            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
601                let block_col = self.indices[k][0];
602                let block = &self.data[k];
603
604                // For each element in the block
605                for (i, block_row_data) in block.iter().enumerate().take(r) {
606                    let row = block_row * r + i;
607                    if row < self.rows {
608                        for (j, &value) in block_row_data.iter().enumerate().take(c) {
609                            let col = block_col * c + j;
610                            if col < self.cols {
611                                result[row] += value * other[col];
612                            }
613                        }
614                    }
615                }
616            }
617        }
618
619        Ok(result)
620    }
621
622    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
623        // For efficiency, convert to COO, transpose, then convert back to BSR
624        self.to_coo()?.transpose()?.to_bsr()
625    }
626
627    fn copy(&self) -> Box<dyn SparseArray<T>> {
628        Box::new(self.clone())
629    }
630
631    fn get(&self, i: usize, j: usize) -> T {
632        if i >= self.rows || j >= self.cols {
633            return T::zero();
634        }
635
636        let (r, c) = self.block_size;
637        let block_row = i / r;
638        let block_col = j / c;
639        let block_row_pos = i % r;
640        let block_col_pos = j % c;
641
642        // Search for the block in the row
643        for k in self.indptr[block_row]..self.indptr[block_row + 1] {
644            if self.indices[k][0] == block_col {
645                return self.data[k][block_row_pos][block_col_pos];
646            }
647        }
648
649        T::zero()
650    }
651
652    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
653        if i >= self.rows || j >= self.cols {
654            return Err(SparseError::IndexOutOfBounds {
655                index: (i, j),
656                shape: (self.rows, self.cols),
657            });
658        }
659
660        let (r, c) = self.block_size;
661        let block_row = i / r;
662        let block_col = j / c;
663        let block_row_pos = i % r;
664        let block_col_pos = j % c;
665
666        // Search for the block in the row
667        for k in self.indptr[block_row]..self.indptr[block_row + 1] {
668            if self.indices[k][0] == block_col {
669                // Block exists, update value
670                self.data[k][block_row_pos][block_col_pos] = value;
671                return Ok(());
672            }
673        }
674
675        // Block doesn't exist, we need to create it
676        if !value.is_zero() {
677            // Find position to insert
678            let pos = self.indptr[block_row + 1];
679
680            // Create new block
681            let mut block = vec![vec![T::zero(); c]; r];
682            block[block_row_pos][block_col_pos] = value;
683
684            // Insert block, indices
685            self.data.insert(pos, block);
686            self.indices.insert(pos, vec![block_col]);
687
688            // Update indptr for subsequent rows
689            for k in (block_row + 1)..=self.block_rows {
690                self.indptr[k] += 1;
691            }
692
693            Ok(())
694        } else {
695            // If value is zero and block doesn't exist, do nothing
696            Ok(())
697        }
698    }
699
700    fn eliminate_zeros(&mut self) {
701        // No need to use block_size variables here
702        let mut new_data = Vec::new();
703        let mut new_indices = Vec::new();
704        let mut new_indptr = vec![0];
705        let mut current_nnz = 0;
706
707        for block_row in 0..self.block_rows {
708            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
709                let block_col = self.indices[k][0];
710                let block = &self.data[k];
711
712                // Check if block has any non-zero elements
713                let mut has_nonzero = false;
714                for row in block {
715                    for &val in row {
716                        if !val.is_zero() {
717                            has_nonzero = true;
718                            break;
719                        }
720                    }
721                    if has_nonzero {
722                        break;
723                    }
724                }
725
726                if has_nonzero {
727                    new_data.push(block.clone());
728                    new_indices.push(vec![block_col]);
729                    current_nnz += 1;
730                }
731            }
732
733            new_indptr.push(current_nnz);
734        }
735
736        self.data = new_data;
737        self.indices = new_indices;
738        self.indptr = new_indptr;
739    }
740
741    fn sort_indices(&mut self) {
742        // No need to use block_size variables here
743        let mut new_data = Vec::new();
744        let mut new_indices = Vec::new();
745        let mut new_indptr = vec![0];
746        let mut current_nnz = 0;
747
748        for block_row in 0..self.block_rows {
749            // Get blocks for this row
750            let mut row_blocks = Vec::new();
751            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
752                row_blocks.push((self.indices[k][0], self.data[k].clone()));
753            }
754
755            // Sort by column index
756            row_blocks.sort_by_key(|&(col_, _)| col_);
757
758            // Add sorted blocks to new data structures
759            for (col, block) in row_blocks {
760                new_data.push(block);
761                new_indices.push(vec![col]);
762                current_nnz += 1;
763            }
764
765            new_indptr.push(current_nnz);
766        }
767
768        self.data = new_data;
769        self.indices = new_indices;
770        self.indptr = new_indptr;
771    }
772
773    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
774        let mut result = self.clone();
775        result.sort_indices();
776        Box::new(result)
777    }
778
779    fn has_sorted_indices(&self) -> bool {
780        for block_row in 0..self.block_rows {
781            let mut prev_col = None;
782
783            for k in self.indptr[block_row]..self.indptr[block_row + 1] {
784                let col = self.indices[k][0];
785
786                if let Some(prev) = prev_col {
787                    if col <= prev {
788                        return false;
789                    }
790                }
791
792                prev_col = Some(col);
793            }
794        }
795
796        true
797    }
798
799    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
800        match axis {
801            None => {
802                // Sum all elements
803                let mut total = T::zero();
804
805                for block in &self.data {
806                    for row in block {
807                        for &val in row {
808                            total += val;
809                        }
810                    }
811                }
812
813                Ok(SparseSum::Scalar(total))
814            }
815            Some(0) => {
816                // Sum along rows (result is 1 x cols)
817                let mut result = vec![T::zero(); self.cols];
818                let (r, c) = self.block_size;
819
820                for block_row in 0..self.block_rows {
821                    for k in self.indptr[block_row]..self.indptr[block_row + 1] {
822                        let block_col = self.indices[k][0];
823                        let block = &self.data[k];
824
825                        for block_row_data in block.iter().take(r) {
826                            for (j, &value) in block_row_data.iter().enumerate().take(c) {
827                                let col = block_col * c + j;
828                                if col < self.cols {
829                                    result[col] += value;
830                                }
831                            }
832                        }
833                    }
834                }
835
836                // Create a sparse array from the result
837                let mut row_indices = Vec::new();
838                let mut col_indices = Vec::new();
839                let mut values = Vec::new();
840
841                for (j, &val) in result.iter().enumerate() {
842                    if !val.is_zero() {
843                        row_indices.push(0);
844                        col_indices.push(j);
845                        values.push(val);
846                    }
847                }
848
849                match CooArray::from_triplets(
850                    &row_indices,
851                    &col_indices,
852                    &values,
853                    (1, self.cols),
854                    false,
855                ) {
856                    Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
857                    Err(e) => Err(e),
858                }
859            }
860            Some(1) => {
861                // Sum along columns (result is rows x 1)
862                let mut result = vec![T::zero(); self.rows];
863                let (r, c) = self.block_size;
864
865                for block_row in 0..self.block_rows {
866                    for k in self.indptr[block_row]..self.indptr[block_row + 1] {
867                        let block = &self.data[k];
868
869                        for (i, block_row_data) in block.iter().enumerate().take(r) {
870                            let row = block_row * r + i;
871                            if row < self.rows {
872                                for &value in block_row_data.iter().take(c) {
873                                    result[row] += value;
874                                }
875                            }
876                        }
877                    }
878                }
879
880                // Create a sparse array from the result
881                let mut row_indices = Vec::new();
882                let mut col_indices = Vec::new();
883                let mut values = Vec::new();
884
885                for (i, &val) in result.iter().enumerate() {
886                    if !val.is_zero() {
887                        row_indices.push(i);
888                        col_indices.push(0);
889                        values.push(val);
890                    }
891                }
892
893                match CooArray::from_triplets(
894                    &row_indices,
895                    &col_indices,
896                    &values,
897                    (self.rows, 1),
898                    false,
899                ) {
900                    Ok(array) => Ok(SparseSum::SparseArray(Box::new(array))),
901                    Err(e) => Err(e),
902                }
903            }
904            _ => Err(SparseError::InvalidAxis),
905        }
906    }
907
908    fn max(&self) -> T {
909        let mut max_val = T::neg_infinity();
910
911        for block in &self.data {
912            for row in block {
913                for &val in row {
914                    max_val = max_val.max(val);
915                }
916            }
917        }
918
919        // If no elements or all negative infinity, return zero
920        if max_val == T::neg_infinity() {
921            T::zero()
922        } else {
923            max_val
924        }
925    }
926
927    fn min(&self) -> T {
928        let mut min_val = T::infinity();
929        let mut has_nonzero = false;
930
931        for block in &self.data {
932            for row in block {
933                for &val in row {
934                    if !val.is_zero() {
935                        has_nonzero = true;
936                        min_val = min_val.min(val);
937                    }
938                }
939            }
940        }
941
942        // If no non-zero elements, return zero
943        if !has_nonzero {
944            T::zero()
945        } else {
946            min_val
947        }
948    }
949
950    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
951        let (row_indices, col_indices, values) = self.to_coo_internal();
952
953        (
954            Array1::from_vec(row_indices),
955            Array1::from_vec(col_indices),
956            Array1::from_vec(values),
957        )
958    }
959
960    fn slice(
961        &self,
962        row_range: (usize, usize),
963        col_range: (usize, usize),
964    ) -> SparseResult<Box<dyn SparseArray<T>>> {
965        let (start_row, end_row) = row_range;
966        let (start_col, end_col) = col_range;
967        let (rows, cols) = self.shape();
968
969        if start_row >= rows || end_row > rows || start_col >= cols || end_col > cols {
970            return Err(SparseError::IndexOutOfBounds {
971                index: (start_row.max(end_row), start_col.max(end_col)),
972                shape: (rows, cols),
973            });
974        }
975
976        if start_row >= end_row || start_col >= end_col {
977            return Err(SparseError::InvalidSliceRange);
978        }
979
980        // Convert to COO, slice, then convert back to BSR
981        let coo = self.to_coo()?;
982        coo.slice(row_range, col_range)?.to_bsr()
983    }
984
985    fn as_any(&self) -> &dyn std::any::Any {
986        self
987    }
988}
989
990// Implement Display for BsrArray for better debugging
991impl<T> fmt::Display for BsrArray<T>
992where
993    T: Float
994        + Add<Output = T>
995        + Sub<Output = T>
996        + Mul<Output = T>
997        + Div<Output = T>
998        + Debug
999        + Copy
1000        + 'static
1001        + std::ops::AddAssign,
1002{
1003    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1004        writeln!(
1005            f,
1006            "BsrArray of shape {:?} with {} stored elements",
1007            (self.rows, self.cols),
1008            self.nnz()
1009        )?;
1010        writeln!(f, "Block size: {:?}", self.block_size)?;
1011        writeln!(f, "Number of blocks: {}", self.data.len())?;
1012
1013        if self.data.len() <= 5 {
1014            for block_row in 0..self.block_rows {
1015                for k in self.indptr[block_row]..self.indptr[block_row + 1] {
1016                    let block_col = self.indices[k][0];
1017                    let block = &self.data[k];
1018
1019                    writeln!(f, "Block at ({block_row}, {block_col}): ")?;
1020                    for row in block {
1021                        write!(f, "  [")?;
1022                        for (j, &val) in row.iter().enumerate() {
1023                            if j > 0 {
1024                                write!(f, ", ")?;
1025                            }
1026                            write!(f, "{val:?}")?;
1027                        }
1028                        writeln!(f, "]")?;
1029                    }
1030                }
1031            }
1032        } else {
1033            writeln!(f, "({} blocks total)", self.data.len())?;
1034        }
1035
1036        Ok(())
1037    }
1038}
1039
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: the 4x4 block-diagonal matrix
    //   [1 2 0 0]
    //   [3 4 0 0]
    //   [0 0 5 6]
    //   [0 0 7 8]
    // stored as two 2x2 blocks, one per block row.
    fn block_diagonal_4x4() -> BsrArray<f64> {
        let top_left = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let bottom_right = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
        BsrArray::new(
            vec![top_left, bottom_right],
            vec![vec![0], vec![1]],
            vec![0, 1, 2],
            (4, 4),
            (2, 2),
        )
        .unwrap()
    }

    // Builds a 2x2 array stored as a single 2x2 block.
    fn single_block_2x2(block: Vec<Vec<f64>>) -> BsrArray<f64> {
        BsrArray::new(vec![block], vec![vec![0]], vec![0, 1], (2, 2), (2, 2)).unwrap()
    }

    // Checks every stored value of the block-diagonal fixture plus one implicit zero.
    fn assert_block_diagonal_values(array: &BsrArray<f64>) {
        let expected = [
            ((0, 0), 1.0),
            ((0, 1), 2.0),
            ((1, 0), 3.0),
            ((1, 1), 4.0),
            ((2, 2), 5.0),
            ((2, 3), 6.0),
            ((3, 2), 7.0),
            ((3, 3), 8.0),
            ((0, 2), 0.0), // zero element outside the stored blocks
        ];
        for &((i, j), value) in expected.iter() {
            assert_eq!(array.get(i, j), value, "mismatch at ({i}, {j})");
        }
    }

    #[test]
    fn test_bsr_array_create() {
        let array = block_diagonal_4x4();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8); // every entry inside the two blocks is non-zero
        assert_block_diagonal_values(&array);
    }

    #[test]
    fn test_bsr_array_from_triplets() {
        // The same block-diagonal matrix, described entry by entry.
        let row_idx = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let col_idx = vec![0, 1, 0, 1, 2, 3, 2, 3];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let array =
            BsrArray::from_triplets(&row_idx, &col_idx, &values, (4, 4), (2, 2)).unwrap();

        assert_eq!(array.shape(), (4, 4));
        assert_eq!(array.block_size, (2, 2));
        assert_eq!(array.nnz(), 8);
        assert_block_diagonal_values(&array);
    }

    #[test]
    fn test_bsr_array_conversion() {
        let array = block_diagonal_4x4();

        let coo = array.to_coo().unwrap();
        assert_eq!(coo.shape(), (4, 4));
        assert_eq!(coo.nnz(), 8);

        let csr = array.to_csr().unwrap();
        assert_eq!(csr.shape(), (4, 4));
        assert_eq!(csr.nnz(), 8);

        let expected = Array2::from_shape_vec(
            (4, 4),
            vec![
                1.0, 2.0, 0.0, 0.0, //
                3.0, 4.0, 0.0, 0.0, //
                0.0, 0.0, 5.0, 6.0, //
                0.0, 0.0, 7.0, 8.0,
            ],
        )
        .unwrap();
        assert_eq!(array.to_array(), expected);
    }

    #[test]
    fn test_bsr_array_operations() {
        let lhs = single_block_2x2(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let rhs = single_block_2x2(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);

        // Element-wise addition.
        let sum = lhs.add(&rhs).unwrap();
        assert_eq!(sum.shape(), (2, 2));
        assert_eq!(sum.get(0, 0), 6.0); // 1+5
        assert_eq!(sum.get(0, 1), 8.0); // 2+6
        assert_eq!(sum.get(1, 0), 10.0); // 3+7
        assert_eq!(sum.get(1, 1), 12.0); // 4+8

        // Element-wise multiplication.
        let product = lhs.mul(&rhs).unwrap();
        assert_eq!(product.shape(), (2, 2));
        assert_eq!(product.get(0, 0), 5.0); // 1*5
        assert_eq!(product.get(0, 1), 12.0); // 2*6
        assert_eq!(product.get(1, 0), 21.0); // 3*7
        assert_eq!(product.get(1, 1), 32.0); // 4*8

        // Matrix multiplication.
        let dot = lhs.dot(&rhs).unwrap();
        assert_eq!(dot.shape(), (2, 2));
        assert_eq!(dot.get(0, 0), 19.0); // 1*5 + 2*7
        assert_eq!(dot.get(0, 1), 22.0); // 1*6 + 2*8
        assert_eq!(dot.get(1, 0), 43.0); // 3*5 + 4*7
        assert_eq!(dot.get(1, 1), 50.0); // 3*6 + 4*8
    }

    #[test]
    fn test_bsr_array_dot_vector() {
        let array = block_diagonal_4x4();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let y = array.dot_vector(&x.view()).unwrap();

        // Each output entry is the dot product of one matrix row with x:
        //   [1*1 + 2*2, 3*1 + 4*2, 5*3 + 6*4, 7*3 + 8*4] = [5, 11, 39, 53]
        assert_eq!(y, Array1::from_vec(vec![5.0, 11.0, 39.0, 53.0]));
    }

    #[test]
    fn test_bsr_array_sum() {
        let array = block_diagonal_4x4();

        // Sum of every entry: 1+2+3+4+5+6+7+8 = 36.
        match array.sum(None).unwrap() {
            SparseSum::Scalar(total) => assert_eq!(total, 36.0),
            _ => panic!("Expected SparseSum::Scalar"),
        }

        // Axis 0 collapses the rows, leaving a 1 x 4 array of column totals.
        match array.sum(Some(0)).unwrap() {
            SparseSum::SparseArray(col_totals) => {
                assert_eq!(col_totals.shape(), (1, 4));
                let want = [4.0, 6.0, 12.0, 14.0]; // 1+3, 2+4, 5+7, 6+8
                for (j, &value) in want.iter().enumerate() {
                    assert_eq!(col_totals.get(0, j), value);
                }
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }

        // Axis 1 collapses the columns, leaving a 4 x 1 array of row totals.
        match array.sum(Some(1)).unwrap() {
            SparseSum::SparseArray(row_totals) => {
                assert_eq!(row_totals.shape(), (4, 1));
                let want = [3.0, 7.0, 11.0, 15.0]; // 1+2, 3+4, 5+6, 7+8
                for (i, &value) in want.iter().enumerate() {
                    assert_eq!(row_totals.get(i, 0), value);
                }
            }
            _ => panic!("Expected SparseSum::SparseArray"),
        }
    }
}