1use crate::banded_array::BandedArray;
4use crate::error::SparseResult;
5use crate::sparray::SparseArray;
6use ndarray::Array2;
7use num_traits::{Float, One, Zero};
8use std::fmt::Debug;
9
/// Matrix-flavoured name for [`BandedArray`].
///
/// This is a plain type alias: `BandedMatrix<T>` and `BandedArray<T>` are the
/// same type, so values can be passed interchangeably between APIs that use
/// either name.
pub type BandedMatrix<T> = BandedArray<T>;
12
13impl<T> BandedMatrix<T>
14where
15 T: Float
16 + Debug
17 + std::fmt::Display
18 + Copy
19 + Zero
20 + One
21 + Send
22 + Sync
23 + 'static
24 + std::ops::AddAssign,
25{
26 pub fn matmul(&self, other: &BandedMatrix<T>) -> SparseResult<BandedMatrix<T>> {
28 let a_dense = self.to_array();
30 let b_dense = other.to_array();
31
32 if a_dense.ncols() != b_dense.nrows() {
33 return Err(crate::error::SparseError::DimensionMismatch {
34 expected: a_dense.ncols(),
35 found: b_dense.nrows(),
36 });
37 }
38
39 let result_dense = a_dense.dot(&b_dense);
40
41 let max_bandwidth = self.kl() + self.ku() + other.kl() + other.ku();
43
44 Self::from_dense(&result_dense, max_bandwidth, max_bandwidth)
46 }
47
48 pub fn from_dense(dense: &Array2<T>, kl: usize, ku: usize) -> SparseResult<Self> {
50 let (rows, cols) = dense.dim();
51 let mut result = Self::zeros((rows, cols), kl, ku);
52
53 for i in 0..rows {
54 for j in 0..cols {
55 if result.is_in_band(i, j) {
56 let val = dense[[i, j]];
57 if !val.is_zero() {
58 result.set_unchecked(i, j, val);
59 }
60 }
61 }
62 }
63
64 Ok(result)
65 }
66
67 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
69 if !self.is_in_band(row, col) {
70 return None;
71 }
72
73 let band_idx = self.ku() + col - row;
74 if band_idx < self.kl() + self.ku() + 1 && row < self.shape().0 {
75 Some(&mut self.data_mut()[[band_idx, row]])
76 } else {
77 None
78 }
79 }
80
81 #[allow(dead_code)]
83 fn banded_data_mut(&mut self) -> &mut Array2<T> {
84 BandedArray::data_mut(self)
85 }
86
87 pub fn set(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
89 if !self.is_in_band(row, col) {
90 if !value.is_zero() {
91 return Err(crate::error::SparseError::ValueError(format!(
92 "Cannot set non-zero element at ({row}, {col}) outside band structure"
93 )));
94 }
95 return Ok(());
96 }
97
98 self.set_unchecked(row, col, value);
99 Ok(())
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// A tridiagonal matrix exposes its diagonal, sub-diagonal and
    /// super-diagonal entries at the expected positions.
    #[test]
    fn test_banded_matrix_creation() {
        let diag = [1.0, 2.0, 3.0];
        let lower = [4.0, 5.0];
        let upper = [6.0, 7.0];

        let matrix = BandedMatrix::tridiagonal(&diag, &lower, &upper).unwrap();

        assert_eq!(matrix.shape(), (3, 3));

        let expected = [
            ((0, 0), 1.0),
            ((1, 1), 2.0),
            ((2, 2), 3.0),
            ((1, 0), 4.0),
            ((2, 1), 5.0),
            ((0, 1), 6.0),
            ((1, 2), 7.0),
        ];
        for &((row, col), want) in expected.iter() {
            assert_eq!(matrix.get(row, col), want);
        }
    }

    /// `set` accepts in-band writes, rejects non-zero out-of-band writes and
    /// tolerates zero out-of-band writes.
    #[test]
    fn test_banded_matrix_set() {
        let mut matrix = BandedMatrix::<f64>::zeros((3, 3), 1, 1);

        // Positions inside the band.
        assert!(matrix.set(0, 0, 1.0).is_ok());
        assert!(matrix.set(0, 1, 2.0).is_ok());
        assert!(matrix.set(1, 0, 3.0).is_ok());

        // (0, 2) lies outside a band with kl = ku = 1.
        assert!(matrix.set(0, 2, 4.0).is_err());
        assert!(matrix.set(0, 2, 0.0).is_ok());

        assert_eq!(matrix.get(0, 0), 1.0);
        assert_eq!(matrix.get(0, 1), 2.0);
        assert_eq!(matrix.get(1, 0), 3.0);
    }

    /// `get_mut` returns `None` outside the band and a live reference to the
    /// stored value on the diagonal.
    #[test]
    fn test_banded_matrix_get_mut() {
        let mut matrix = BandedMatrix::<f64>::zeros((3, 3), 1, 1);

        assert!(matrix.get_mut(0, 2).is_none());

        *matrix
            .get_mut(1, 1)
            .expect("diagonal entry lies inside the band") = 9.0;
        assert_eq!(matrix.get(1, 1), 9.0);
    }

    /// The product of two tridiagonal matrices matches the dense product,
    /// including the fill-in two diagonals away from the main diagonal.
    #[test]
    fn test_banded_matrix_matmul() {
        let a = BandedMatrix::tridiagonal(&[2.0, 2.0, 2.0], &[1.0, 1.0], &[1.0, 1.0]).unwrap();
        let b = BandedMatrix::tridiagonal(&[1.0, 1.0, 1.0], &[0.5, 0.5], &[0.5, 0.5]).unwrap();

        let c = a.matmul(&b).unwrap();

        assert!(c.shape() == (3, 3));

        // [2 1 0]   [1   0.5 0  ]   [2.5 2.0 0.5]
        // [1 2 1] * [0.5 1   0.5] = [2.0 3.0 2.0]
        // [0 1 2]   [0   0.5 1  ]   [0.5 2.0 2.5]
        let expected = [[2.5, 2.0, 0.5], [2.0, 3.0, 2.0], [0.5, 2.0, 2.5]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_relative_eq!(c.get(i, j), want, epsilon = 1e-10);
            }
        }
    }

    /// Operands whose inner dimensions disagree are rejected.
    #[test]
    fn test_banded_matrix_matmul_dimension_mismatch() {
        let a = BandedMatrix::<f64>::zeros((3, 3), 1, 1);
        let b = BandedMatrix::<f64>::zeros((2, 2), 1, 1);

        assert!(a.matmul(&b).is_err());
    }

    /// `from_dense` keeps the in-band entries of a dense tridiagonal matrix
    /// and reports zeros outside the band.
    #[test]
    fn test_from_dense() {
        let dense =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 0.0, 3.0, 4.0, 5.0, 0.0, 6.0, 7.0])
                .unwrap();

        let banded = BandedMatrix::from_dense(&dense, 1, 1).unwrap();

        let in_band = [
            ((0, 0), 1.0),
            ((0, 1), 2.0),
            ((1, 0), 3.0),
            ((1, 1), 4.0),
            ((1, 2), 5.0),
            ((2, 1), 6.0),
            ((2, 2), 7.0),
        ];
        for &((row, col), want) in in_band.iter() {
            assert_eq!(banded.get(row, col), want);
        }

        // Corners outside the tridiagonal band.
        assert_eq!(banded.get(0, 2), 0.0);
        assert_eq!(banded.get(2, 0), 0.0);
    }
}