1use crate::error::{LinalgError, LinalgResult};
4use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
5use scirs2_core::numeric::{Float, NumAssign};
6use std::iter::Sum;
7
8#[allow(dead_code)]
30pub fn det<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<F>
31where
32 F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
33{
34 use crate::parallel;
35
36 parallel::configure_workers(workers);
38
39 if a.nrows() != a.ncols() {
40 let rows = a.nrows();
41 let cols = a.ncols();
42 return Err(LinalgError::ShapeError(format!(
43 "Determinant computation failed: Matrix must be square\nMatrix shape: {rows}×{cols}\nExpected: Square matrix (n×n)"
44 )));
45 }
46
47 match a.nrows() {
49 0 => Ok(F::one()),
50 1 => Ok(a[[0, 0]]),
51 2 => Ok(a[[0, 0]] * a[[1, 1]] - a[[0, 1]] * a[[1, 0]]),
52 3 => {
53 let det = a[[0, 0]] * (a[[1, 1]] * a[[2, 2]] - a[[1, 2]] * a[[2, 1]])
54 - a[[0, 1]] * (a[[1, 0]] * a[[2, 2]] - a[[1, 2]] * a[[2, 0]])
55 + a[[0, 2]] * (a[[1, 0]] * a[[2, 1]] - a[[1, 1]] * a[[2, 0]]);
56 Ok(det)
57 }
58 _ => {
59 use crate::decomposition::lu;
61
62 match lu(a, workers) {
63 Ok((p, _l, u)) => {
64 let mut det_u = F::one();
66 for i in 0..u.nrows() {
67 det_u *= u[[i, i]];
68 }
69
70 let mut swap_count = 0;
72 for i in 0..p.nrows() {
73 for j in 0..i {
74 if p[[i, j]] == F::one() {
75 swap_count += 1;
76 }
77 }
78 }
79
80 if swap_count % 2 == 0 {
82 Ok(det_u)
83 } else {
84 Ok(-det_u)
85 }
86 }
87 Err(LinalgError::SingularMatrixError(_)) => {
88 Ok(F::zero())
90 }
91 Err(e) => Err(e),
92 }
93 }
94 }
95}
96
97#[allow(dead_code)]
120pub fn inv<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
121where
122 F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
123{
124 use crate::parallel;
125
126 parallel::configure_workers(workers);
128
129 if a.nrows() != a.ncols() {
130 let rows = a.nrows();
131 let cols = a.ncols();
132 return Err(LinalgError::ShapeError(format!(
133 "Matrix inverse computation failed: Matrix must be square\nMatrix shape: {rows}×{cols}\nExpected: Square matrix (n×n)"
134 )));
135 }
136
137 if a.nrows() == 2 {
139 let det_val = det(a, workers)?;
140 if det_val.abs() < F::epsilon() {
141 let norm_a = (a[[0, 0]] * a[[0, 0]]
143 + a[[0, 1]] * a[[0, 1]]
144 + a[[1, 0]] * a[[1, 0]]
145 + a[[1, 1]] * a[[1, 1]])
146 .sqrt();
147 let cond_estimate = if det_val.abs() > F::zero() {
148 Some((norm_a / det_val.abs()).to_f64().unwrap_or(1e16))
149 } else {
150 None
151 };
152
153 return Err(LinalgError::singularmatrix_with_suggestions(
154 "matrix inverse",
155 a.dim(),
156 cond_estimate,
157 ));
158 }
159
160 let inv_det = F::one() / det_val;
161 let mut result = Array2::zeros((2, 2));
162 result[[0, 0]] = a[[1, 1]] * inv_det;
163 result[[0, 1]] = -a[[0, 1]] * inv_det;
164 result[[1, 0]] = -a[[1, 0]] * inv_det;
165 result[[1, 1]] = a[[0, 0]] * inv_det;
166 return Ok(result);
167 }
168
169 use crate::solve::solve_multiple;
171
172 let n = a.nrows();
173 let mut identity = Array2::zeros((n, n));
174 for i in 0..n {
175 identity[[i, i]] = F::one();
176 }
177
178 match solve_multiple(a, &identity.view(), workers) {
180 Err(LinalgError::SingularMatrixError(_)) => {
181 Err(LinalgError::singularmatrix_with_suggestions(
183 "matrix inverse via solve",
184 a.dim(),
185 None, ))
187 }
188 other => other,
189 }
190}
191
192#[allow(dead_code)]
220pub fn matrix_power<F>(a: &ArrayView2<F>, n: i32, workers: Option<usize>) -> LinalgResult<Array2<F>>
221where
222 F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
223{
224 use crate::parallel;
225
226 parallel::configure_workers(workers);
228
229 if a.nrows() != a.ncols() {
230 let rows = a.nrows();
231 let cols = a.ncols();
232 return Err(LinalgError::ShapeError(format!(
233 "Matrix power computation failed: Matrix must be square\nMatrix shape: {rows}×{cols}\nExpected: Square matrix (n×n)"
234 )));
235 }
236
237 let dim = a.nrows();
238
239 if n == 0 {
241 let mut result = Array2::zeros((dim, dim));
243 for i in 0..dim {
244 result[[i, i]] = F::one();
245 }
246 return Ok(result);
247 }
248
249 if n == 1 {
250 return Ok(a.to_owned());
252 }
253
254 if n == -1 {
255 return inv(a, workers);
257 }
258
259 if n.abs() > 1 {
260 return Err(LinalgError::NotImplementedError(
264 "Matrix power for |n| > 1 not yet implemented".to_string(),
265 ));
266 }
267
268 Err(LinalgError::ComputationError(
270 "Unexpected error in matrix power calculation".to_string(),
271 ))
272}
273
274#[allow(dead_code)]
297pub fn trace<F>(a: &ArrayView2<F>) -> LinalgResult<F>
298where
299 F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
300{
301 if a.nrows() != a.ncols() {
302 let rows = a.nrows();
303 let cols = a.ncols();
304 return Err(LinalgError::ShapeError(format!(
305 "Matrix trace computation failed: Matrix must be square\nMatrix shape: {rows}×{cols}\nExpected: Square matrix (n×n)"
306 )));
307 }
308
309 let mut tr = F::zero();
310 for i in 0..a.nrows() {
311 tr += a[[i, i]];
312 }
313
314 Ok(tr)
315}
316
317#[allow(dead_code)]
326pub fn det_default<F>(a: &ArrayView2<F>) -> LinalgResult<F>
327where
328 F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
329{
330 det(a, None)
331}
332
333#[allow(dead_code)]
338pub fn inv_default<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
339where
340 F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
341{
342 inv(a, None)
343}
344
345#[allow(dead_code)]
350pub fn matrix_power_default<F>(a: &ArrayView2<F>, n: i32) -> LinalgResult<Array2<F>>
351where
352 F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
353{
354 matrix_power(a, n, None)
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_det_2x2() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let d = det(&a.view(), None).unwrap();
        assert!((d - (-2.0)).abs() < 1e-10);

        let b = array![[2.0, 0.0], [0.0, 3.0]];
        let d = det(&b.view(), None).unwrap();
        assert!((d - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_det_3x3() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let d = det(&a.view(), None).unwrap();
        assert!((d - 0.0).abs() < 1e-10);

        let b = array![[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]];
        let d = det(&b.view(), None).unwrap();
        assert!((d - 24.0).abs() < 1e-10);
    }

    #[test]
    fn test_det_small_sizes() {
        // The empty matrix has determinant 1 by convention.
        let empty = Array2::<f64>::zeros((0, 0));
        assert_relative_eq!(det(&empty.view(), None).unwrap(), 1.0);

        let single = array![[7.5]];
        assert_relative_eq!(det(&single.view(), None).unwrap(), 7.5);
    }

    #[test]
    fn test_inv_2x2() {
        let a = array![[1.0, 0.0], [0.0, 2.0]];
        let a_inv = inv(&a.view(), None).unwrap();
        assert_relative_eq!(a_inv[[0, 0]], 1.0);
        assert_relative_eq!(a_inv[[0, 1]], 0.0);
        assert_relative_eq!(a_inv[[1, 0]], 0.0);
        assert_relative_eq!(a_inv[[1, 1]], 0.5);

        let b = array![[1.0, 2.0], [3.0, 4.0]];
        let b_inv = inv(&b.view(), None).unwrap();
        assert_relative_eq!(b_inv[[0, 0]], -2.0);
        assert_relative_eq!(b_inv[[0, 1]], 1.0);
        assert_relative_eq!(b_inv[[1, 0]], 1.5);
        assert_relative_eq!(b_inv[[1, 1]], -0.5);
    }

    #[test]
    fn test_inv_singular_2x2() {
        // Second row is exactly twice the first, so det == 0.
        let a = array![[1.0, 2.0], [2.0, 4.0]];
        assert!(inv(&a.view(), None).is_err());
    }

    #[test]
    fn test_inv_large() {
        let a = array![[1.0, 2.0, 3.0], [0.0, 1.0, 4.0], [5.0, 6.0, 0.0]];
        let a_inv = inv(&a.view(), None).unwrap();

        // A * A^-1 must be the identity.
        let product = a.dot(&a_inv);
        let n = a.nrows();
        for i in 0..n {
            for j in 0..n {
                if i == j {
                    assert_relative_eq!(product[[i, j]], 1.0, epsilon = 1e-10);
                } else {
                    assert_relative_eq!(product[[i, j]], 0.0, epsilon = 1e-10);
                }
            }
        }

        let b = array![
            [2.0, 0.0, 0.0, 0.0],
            [0.0, 3.0, 0.0, 0.0],
            [0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 5.0]
        ];
        let b_inv = inv(&b.view(), None).unwrap();
        assert_relative_eq!(b_inv[[0, 0]], 0.5, epsilon = 1e-10);
        assert_relative_eq!(b_inv[[1, 1]], 1.0 / 3.0, epsilon = 1e-10);
        assert_relative_eq!(b_inv[[2, 2]], 0.25, epsilon = 1e-10);
        assert_relative_eq!(b_inv[[3, 3]], 0.2, epsilon = 1e-10);

        // Rank-1 matrix: must be reported as singular.
        let c = array![[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]];
        assert!(inv(&c.view(), None).is_err());
    }

    #[test]
    fn test_matrix_power() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];

        let a_0 = matrix_power(&a.view(), 0, None).unwrap();
        assert_relative_eq!(a_0[[0, 0]], 1.0);
        assert_relative_eq!(a_0[[0, 1]], 0.0);
        assert_relative_eq!(a_0[[1, 0]], 0.0);
        assert_relative_eq!(a_0[[1, 1]], 1.0);

        let a_1 = matrix_power(&a.view(), 1, None).unwrap();
        assert_relative_eq!(a_1[[0, 0]], a[[0, 0]]);
        assert_relative_eq!(a_1[[0, 1]], a[[0, 1]]);
        assert_relative_eq!(a_1[[1, 0]], a[[1, 0]]);
        assert_relative_eq!(a_1[[1, 1]], a[[1, 1]]);

        // n == -1 must agree with the explicit inverse.
        let a_neg1 = matrix_power(&a.view(), -1, None).unwrap();
        let expected = inv(&a.view(), None).unwrap();
        for (x, y) in a_neg1.iter().zip(expected.iter()) {
            assert_relative_eq!(*x, *y, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_trace() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        assert_relative_eq!(trace(&a.view()).unwrap(), 5.0);

        let empty = Array2::<f64>::zeros((0, 0));
        assert_relative_eq!(trace(&empty.view()).unwrap(), 0.0);

        let rect = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert!(matches!(trace(&rect.view()), Err(LinalgError::ShapeError(_))));
    }

    #[test]
    fn test_non_square_inputs_rejected() {
        let rect = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        assert!(matches!(det(&rect.view(), None), Err(LinalgError::ShapeError(_))));
        assert!(matches!(inv(&rect.view(), None), Err(LinalgError::ShapeError(_))));
        assert!(matches!(
            matrix_power(&rect.view(), 2, None),
            Err(LinalgError::ShapeError(_))
        ));
    }

    #[test]
    fn test_det_large() {
        // Tridiagonal matrix with 2 on the diagonal and 1 off it: det = n + 1 = 5.
        let a = array![
            [2.0, 1.0, 0.0, 0.0],
            [1.0, 2.0, 1.0, 0.0],
            [0.0, 1.0, 2.0, 1.0],
            [0.0, 0.0, 1.0, 2.0]
        ];
        let d = det(&a.view(), None).unwrap();
        assert_relative_eq!(d, 5.0, epsilon = 1e-10);

        let b = array![
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 3.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 4.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 5.0]
        ];
        let d = det(&b.view(), None).unwrap();
        assert_relative_eq!(d, 120.0, epsilon = 1e-10);

        // Rank-1 matrix: determinant is zero.
        let c = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 4.0, 6.0, 8.0],
            [3.0, 6.0, 9.0, 12.0],
            [4.0, 8.0, 12.0, 16.0]
        ];
        let d = det(&c.view(), None).unwrap();
        assert_relative_eq!(d, 0.0, epsilon = 1e-10);
    }
}