scirs2_linalg/matrix_functions/
utils.rs1use scirs2_core::ndarray::{Array2, ArrayView2};
4use scirs2_core::numeric::{Float, NumAssign, One};
5use std::iter::Sum;
6
7use crate::error::{LinalgError, LinalgResult};
8
9pub fn is_integer<F: Float>(x: F) -> bool {
11 (x - x.round()).abs() < F::from(1e-10).unwrap_or(F::epsilon())
12}
13
14pub fn is_diagonal<F>(a: &ArrayView2<F>) -> bool
16where
17 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
18{
19 let n = a.nrows();
20 for i in 0..n {
21 for j in 0..n {
22 if i != j && a[[i, j]].abs() > F::epsilon() {
23 return false;
24 }
25 }
26 }
27 true
28}
29
30pub fn is_symmetric<F>(a: &ArrayView2<F>) -> bool
32where
33 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
34{
35 let n = a.nrows();
36 if n != a.ncols() {
37 return false;
38 }
39
40 for i in 0..n {
41 for j in 0..n {
42 if (a[[i, j]] - a[[j, i]]).abs() > F::epsilon() {
43 return false;
44 }
45 }
46 }
47 true
48}
49
50pub fn is_zero_matrix<F>(a: &ArrayView2<F>) -> bool
52where
53 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
54{
55 let (m, n) = a.dim();
56 for i in 0..m {
57 for j in 0..n {
58 if a[[i, j]].abs() > F::epsilon() {
59 return false;
60 }
61 }
62 }
63 true
64}
65
66pub fn is_identity<F>(a: &ArrayView2<F>) -> bool
68where
69 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
70{
71 let n = a.nrows();
72 if n != a.ncols() {
73 return false;
74 }
75
76 for i in 0..n {
77 for j in 0..n {
78 let expected = if i == j { F::one() } else { F::zero() };
79 if (a[[i, j]] - expected).abs() > F::epsilon() {
80 return false;
81 }
82 }
83 }
84 true
85}
86
87pub fn matrix_multiply<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
89where
90 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
91{
92 let (m, k1) = a.dim();
93 let (k2, n) = b.dim();
94
95 if k1 != k2 {
96 return Err(LinalgError::ShapeError(format!(
97 "Matrix dimensions incompatible for multiplication: ({}, {}) × ({}, {})",
98 m, k1, k2, n
99 )));
100 }
101
102 let mut c = Array2::<F>::zeros((m, n));
103 for i in 0..m {
104 for j in 0..n {
105 for k in 0..k1 {
106 c[[i, j]] += a[[i, k]] * b[[k, j]];
107 }
108 }
109 }
110
111 Ok(c)
112}
113
114pub fn integer_matrix_power<F>(a: &ArrayView2<F>, p: i32) -> LinalgResult<Array2<F>>
116where
117 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
118{
119 use crate::solve::solve_multiple;
120
121 let n = a.nrows();
122
123 if p == 0 {
124 return Ok(Array2::eye(n));
125 }
126
127 if p == 1 {
128 return Ok(a.to_owned());
129 }
130
131 if p < 0 {
132 let a_inv = solve_multiple(a, &Array2::eye(n).view(), None)?;
134 return integer_matrix_power(&a_inv.view(), -p);
135 }
136
137 let mut result = Array2::eye(n);
139 let mut base = a.to_owned();
140 let mut exp = p as u32;
141
142 while exp > 0 {
143 if exp % 2 == 1 {
144 result = matrix_multiply(&result.view(), &base.view())?;
145 }
146 base = matrix_multiply(&base.view(), &base.view())?;
147 exp /= 2;
148 }
149
150 Ok(result)
151}
152
153pub fn frobenius_norm<F>(a: &ArrayView2<F>) -> F
155where
156 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
157{
158 let (m, n) = a.dim();
159 let mut sum = F::zero();
160
161 for i in 0..m {
162 for j in 0..n {
163 sum += a[[i, j]] * a[[i, j]];
164 }
165 }
166
167 sum.sqrt()
168}
169
170pub fn matrix_diff_norm<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<F>
172where
173 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
174{
175 if a.dim() != b.dim() {
176 return Err(LinalgError::ShapeError(
177 "Matrices must have the same dimensions".to_string(),
178 ));
179 }
180
181 let (m, n) = a.dim();
182 let mut max_diff = F::zero();
183
184 for i in 0..m {
185 for j in 0..n {
186 let diff = (a[[i, j]] - b[[i, j]]).abs();
187 if diff > max_diff {
188 max_diff = diff;
189 }
190 }
191 }
192
193 Ok(max_diff)
194}
195
196pub fn scale_matrix<F>(a: &ArrayView2<F>, alpha: F) -> Array2<F>
198where
199 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
200{
201 let (m, n) = a.dim();
202 let mut result = Array2::<F>::zeros((m, n));
203
204 for i in 0..m {
205 for j in 0..n {
206 result[[i, j]] = alpha * a[[i, j]];
207 }
208 }
209
210 result
211}
212
213pub fn matrix_add<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
215where
216 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
217{
218 if a.dim() != b.dim() {
219 return Err(LinalgError::ShapeError(
220 "Matrices must have the same dimensions for addition".to_string(),
221 ));
222 }
223
224 let (m, n) = a.dim();
225 let mut result = Array2::<F>::zeros((m, n));
226
227 for i in 0..m {
228 for j in 0..n {
229 result[[i, j]] = a[[i, j]] + b[[i, j]];
230 }
231 }
232
233 Ok(result)
234}
235
236pub fn matrix_subtract<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
238where
239 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
240{
241 if a.dim() != b.dim() {
242 return Err(LinalgError::ShapeError(
243 "Matrices must have the same dimensions for subtraction".to_string(),
244 ));
245 }
246
247 let (m, n) = a.dim();
248 let mut result = Array2::<F>::zeros((m, n));
249
250 for i in 0..m {
251 for j in 0..n {
252 result[[i, j]] = a[[i, j]] - b[[i, j]];
253 }
254 }
255
256 Ok(result)
257}
258
259pub fn matrix_transpose<F>(a: &ArrayView2<F>) -> Array2<F>
261where
262 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
263{
264 let (m, n) = a.dim();
265 let mut result = Array2::<F>::zeros((n, m));
266
267 for i in 0..m {
268 for j in 0..n {
269 result[[j, i]] = a[[i, j]];
270 }
271 }
272
273 result
274}
275
276pub fn check_positive_definite<F>(eigenvals: &[F]) -> bool
278where
279 F: Float,
280{
281 eigenvals.iter().all(|&val| val > F::zero())
282}
283
284pub fn check_positive_semidefinite<F>(eigenvals: &[F]) -> bool
286where
287 F: Float,
288{
289 eigenvals.iter().all(|&val| val >= F::zero())
290}