scirs2_sparse/
banded.rs

1//! Banded matrix format (legacy matrix API)
2
3use crate::banded_array::BandedArray;
4use crate::error::SparseResult;
5use crate::sparray::SparseArray;
6use ndarray::Array2;
7use num_traits::{Float, One, Zero};
8use std::fmt::Debug;
9
10/// Legacy banded matrix wrapper around BandedArray
11pub type BandedMatrix<T> = BandedArray<T>;
12
13impl<T> BandedMatrix<T>
14where
15    T: Float
16        + Debug
17        + std::fmt::Display
18        + Copy
19        + Zero
20        + One
21        + Send
22        + Sync
23        + 'static
24        + std::ops::AddAssign,
25{
26    /// Matrix multiplication (for legacy API compatibility)
27    pub fn matmul(&self, other: &BandedMatrix<T>) -> SparseResult<BandedMatrix<T>> {
28        // Convert to dense for multiplication, then back to banded
29        let a_dense = self.to_array();
30        let b_dense = other.to_array();
31
32        if a_dense.ncols() != b_dense.nrows() {
33            return Err(crate::error::SparseError::DimensionMismatch {
34                expected: a_dense.ncols(),
35                found: b_dense.nrows(),
36            });
37        }
38
39        let result_dense = a_dense.dot(&b_dense);
40
41        // Estimate bandwidth of result
42        let max_bandwidth = self.kl() + self.ku() + other.kl() + other.ku();
43
44        // Extract banded structure from result
45        Self::from_dense(&result_dense, max_bandwidth, max_bandwidth)
46    }
47
48    /// Create banded matrix from dense array
49    pub fn from_dense(dense: &Array2<T>, kl: usize, ku: usize) -> SparseResult<Self> {
50        let (rows, cols) = dense.dim();
51        let mut result = Self::zeros((rows, cols), kl, ku);
52
53        for i in 0..rows {
54            for j in 0..cols {
55                if result.is_in_band(i, j) {
56                    let val = dense[[i, j]];
57                    if !val.is_zero() {
58                        result.set_unchecked(i, j, val);
59                    }
60                }
61            }
62        }
63
64        Ok(result)
65    }
66
67    /// Get a mutable reference to an element (legacy API)
68    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
69        if !self.is_in_band(row, col) {
70            return None;
71        }
72
73        let band_idx = self.ku() + col - row;
74        if band_idx < self.kl() + self.ku() + 1 && row < self.shape().0 {
75            Some(&mut self.data_mut()[[band_idx, row]])
76        } else {
77            None
78        }
79    }
80
81    /// Get mutable reference to data (private helper)  
82    #[allow(dead_code)]
83    fn banded_data_mut(&mut self) -> &mut Array2<T> {
84        BandedArray::data_mut(self)
85    }
86
87    /// Set element (legacy API)
88    pub fn set(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
89        if !self.is_in_band(row, col) {
90            if !value.is_zero() {
91                return Err(crate::error::SparseError::ValueError(format!(
92                    "Cannot set non-zero element at ({row}, {col}) outside band structure"
93                )));
94            }
95            return Ok(());
96        }
97
98        self.set_unchecked(row, col, value);
99        Ok(())
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SparseError;
    use approx::assert_relative_eq;

    #[test]
    fn test_banded_matrix_creation() {
        let diag = vec![1.0, 2.0, 3.0];
        let lower = vec![4.0, 5.0];
        let upper = vec![6.0, 7.0];

        let matrix = BandedMatrix::tridiagonal(&diag, &lower, &upper).unwrap();

        assert_eq!(matrix.shape(), (3, 3));
        assert_eq!(matrix.get(0, 0), 1.0);
        assert_eq!(matrix.get(1, 1), 2.0);
        assert_eq!(matrix.get(2, 2), 3.0);
        assert_eq!(matrix.get(1, 0), 4.0);
        assert_eq!(matrix.get(2, 1), 5.0);
        assert_eq!(matrix.get(0, 1), 6.0);
        assert_eq!(matrix.get(1, 2), 7.0);
    }

    #[test]
    fn test_banded_matrix_set() {
        let mut matrix = BandedMatrix::<f64>::zeros((3, 3), 1, 1);

        // Should succeed for in-band elements
        assert!(matrix.set(0, 0, 1.0).is_ok());
        assert!(matrix.set(0, 1, 2.0).is_ok());
        assert!(matrix.set(1, 0, 3.0).is_ok());

        // Should fail for out-of-band non-zero elements
        assert!(matrix.set(0, 2, 4.0).is_err());

        // Should succeed for out-of-band zero elements
        assert!(matrix.set(0, 2, 0.0).is_ok());

        assert_eq!(matrix.get(0, 0), 1.0);
        assert_eq!(matrix.get(0, 1), 2.0);
        assert_eq!(matrix.get(1, 0), 3.0);
    }

    #[test]
    fn test_banded_matrix_matmul() {
        let a = BandedMatrix::tridiagonal(&[2.0, 2.0, 2.0], &[1.0, 1.0], &[1.0, 1.0]).unwrap();

        let b = BandedMatrix::tridiagonal(&[1.0, 1.0, 1.0], &[0.5, 0.5], &[0.5, 0.5]).unwrap();

        let c = a.matmul(&b).unwrap();

        assert!(c.shape() == (3, 3));

        // Dense reference product; checking every entry catches a result band
        // that is too narrow and silently drops non-zeros.
        let expected = [[2.5, 2.0, 0.5], [2.0, 3.0, 2.0], [0.5, 2.0, 2.5]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_relative_eq!(c.get(i, j), want, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_banded_matrix_matmul_dimension_mismatch() {
        let a = BandedMatrix::tridiagonal(&[1.0, 1.0, 1.0], &[0.0, 0.0], &[0.0, 0.0]).unwrap();
        let b = BandedMatrix::tridiagonal(&[1.0, 1.0], &[0.0], &[0.0]).unwrap();

        assert!(matches!(
            a.matmul(&b),
            Err(SparseError::DimensionMismatch {
                expected: 3,
                found: 2
            })
        ));
    }

    #[test]
    fn test_banded_matrix_get_mut() {
        let mut matrix =
            BandedMatrix::tridiagonal(&[1.0, 2.0, 3.0], &[4.0, 5.0], &[6.0, 7.0]).unwrap();

        match matrix.get_mut(1, 1) {
            Some(v) => *v = 10.0,
            None => panic!("diagonal entry must be addressable"),
        }
        assert_eq!(matrix.get(1, 1), 10.0);

        // Outside the band, or past the last row, there is nothing to borrow.
        assert!(matrix.get_mut(0, 2).is_none());
        assert!(matrix.get_mut(3, 3).is_none());
    }

    #[test]
    fn test_from_dense() {
        let dense =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 0.0, 3.0, 4.0, 5.0, 0.0, 6.0, 7.0])
                .unwrap();

        let banded = BandedMatrix::from_dense(&dense, 1, 1).unwrap();

        assert_eq!(banded.get(0, 0), 1.0);
        assert_eq!(banded.get(0, 1), 2.0);
        assert_eq!(banded.get(1, 0), 3.0);
        assert_eq!(banded.get(1, 1), 4.0);
        assert_eq!(banded.get(1, 2), 5.0);
        assert_eq!(banded.get(2, 1), 6.0);
        assert_eq!(banded.get(2, 2), 7.0);

        // Out-of-band elements should be zero
        assert_eq!(banded.get(0, 2), 0.0);
        assert_eq!(banded.get(2, 0), 0.0);
    }

    #[test]
    fn test_from_dense_diagonal_only() {
        let dense = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let banded = BandedMatrix::from_dense(&dense, 0, 0).unwrap();

        assert_eq!(banded.get(0, 0), 1.0);
        assert_eq!(banded.get(1, 1), 4.0);
        // Off-diagonal values are outside a zero-width band and must be dropped.
        assert_eq!(banded.get(0, 1), 0.0);
        assert_eq!(banded.get(1, 0), 0.0);
    }
}