1use scirs2_core::ndarray::{Array2, ArrayView2};
4use scirs2_core::numeric::{Float, NumAssign, One};
5use std::iter::Sum;
6
7use crate::eigen::eig;
8use crate::error::{LinalgError, LinalgResult};
9use crate::norm::matrix_norm;
10use crate::solve::solve_multiple;
11use crate::validation::validate_decomposition;
12
13#[allow(dead_code)]
44pub fn expm<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
45where
46 F: Float + NumAssign + Sum + One + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
47{
48 use crate::parallel;
49
50 parallel::configure_workers(workers);
52
53 validate_decomposition(a, "Matrix exponential computation", true)?;
55
56 let n = a.nrows();
57
58 if n == 1 {
60 let mut result = Array2::<F>::zeros((1, 1));
61 result[[0, 0]] = a[[0, 0]].exp();
62 return Ok(result);
63 }
64
65 let mut is_diagonal = true;
67 for i in 0..n {
68 for j in 0..n {
69 if i != j && a[[i, j]].abs() > F::epsilon() {
70 is_diagonal = false;
71 break;
72 }
73 }
74 if !is_diagonal {
75 break;
76 }
77 }
78
79 if is_diagonal {
80 let mut result = Array2::<F>::zeros((n, n));
81 for i in 0..n {
82 result[[i, i]] = a[[i, i]].exp();
83 }
84 return Ok(result);
85 }
86
87 let norm_a = matrix_norm(a, "1", None)?;
89 let scaling_f = norm_a.log2().ceil().max(F::zero());
90 let scaling = scaling_f.to_i32().unwrap_or(0);
91 let s = F::from(2.0_f64.powi(-scaling)).unwrap_or(F::one());
92
93 let mut a_scaled = Array2::<F>::zeros((n, n));
95 for i in 0..n {
96 for j in 0..n {
97 a_scaled[[i, j]] = a[[i, j]] * s;
98 }
99 }
100
101 let c = [
103 F::from(1.0).unwrap(),
104 F::from(1.0 / 2.0).unwrap(),
105 F::from(1.0 / 6.0).unwrap(),
106 F::from(1.0 / 24.0).unwrap(),
107 F::from(1.0 / 120.0).unwrap(),
108 F::from(1.0 / 720.0).unwrap(),
109 ];
110
111 let mut a2 = Array2::<F>::zeros((n, n));
113 for i in 0..n {
114 for j in 0..n {
115 for k in 0..n {
116 a2[[i, j]] += a_scaled[[i, k]] * a_scaled[[k, j]];
117 }
118 }
119 }
120
121 let mut a4 = Array2::<F>::zeros((n, n));
122 for i in 0..n {
123 for j in 0..n {
124 for k in 0..n {
125 a4[[i, j]] += a2[[i, k]] * a2[[k, j]];
126 }
127 }
128 }
129
130 let mut n_pade = Array2::<F>::zeros((n, n));
132 for i in 0..n {
133 n_pade[[i, i]] = c[0]; }
135
136 for i in 0..n {
138 for j in 0..n {
139 n_pade[[i, j]] += c[1] * a_scaled[[i, j]];
140 }
141 }
142
143 for i in 0..n {
145 for j in 0..n {
146 n_pade[[i, j]] += c[2] * a2[[i, j]];
147 }
148 }
149
150 let mut a3 = Array2::<F>::zeros((n, n));
152 for i in 0..n {
153 for j in 0..n {
154 for k in 0..n {
155 a3[[i, j]] += a_scaled[[i, k]] * a2[[k, j]];
156 }
157 }
158 }
159
160 for i in 0..n {
161 for j in 0..n {
162 n_pade[[i, j]] += c[3] * a3[[i, j]];
163 }
164 }
165
166 for i in 0..n {
168 for j in 0..n {
169 n_pade[[i, j]] += c[4] * a4[[i, j]];
170 }
171 }
172
173 let mut a5 = Array2::<F>::zeros((n, n));
175 for i in 0..n {
176 for j in 0..n {
177 for k in 0..n {
178 a5[[i, j]] += a_scaled[[i, k]] * a4[[k, j]];
179 }
180 }
181 }
182
183 for i in 0..n {
184 for j in 0..n {
185 n_pade[[i, j]] += c[5] * a5[[i, j]];
186 }
187 }
188
189 let mut d_pade = Array2::<F>::zeros((n, n));
191 for i in 0..n {
192 d_pade[[i, i]] = c[0]; }
194
195 for i in 0..n {
197 for j in 0..n {
198 d_pade[[i, j]] -= c[1] * a_scaled[[i, j]];
199 }
200 }
201
202 for i in 0..n {
204 for j in 0..n {
205 d_pade[[i, j]] += c[2] * a2[[i, j]];
206 }
207 }
208
209 for i in 0..n {
211 for j in 0..n {
212 d_pade[[i, j]] -= c[3] * a3[[i, j]];
213 }
214 }
215
216 for i in 0..n {
218 for j in 0..n {
219 d_pade[[i, j]] += c[4] * a4[[i, j]];
220 }
221 }
222
223 for i in 0..n {
225 for j in 0..n {
226 d_pade[[i, j]] -= c[5] * a5[[i, j]];
227 }
228 }
229
230 let result = solve_multiple(&d_pade.view(), &n_pade.view(), None)?;
232
233 let mut exp_a = result;
235
236 for _ in 0..scaling as usize {
237 let mut temp = Array2::<F>::zeros((n, n));
238 for i in 0..n {
239 for j in 0..n {
240 for k in 0..n {
241 temp[[i, j]] += exp_a[[i, k]] * exp_a[[k, j]];
242 }
243 }
244 }
245 exp_a = temp;
246 }
247
248 Ok(exp_a)
249}
250
251#[allow(dead_code)]
282pub fn logm<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
283where
284 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
285{
286 logm_impl(a)
287}
288
289#[allow(dead_code)]
291fn logm_impl<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
292where
293 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
294{
295 if a.nrows() != a.ncols() {
296 return Err(LinalgError::ShapeError(format!(
297 "Matrix must be square to compute logarithm, got shape {:?}",
298 a.shape()
299 )));
300 }
301
302 let n = a.nrows();
303
304 if n == 1 {
306 let val = a[[0, 0]];
307 if val <= F::zero() {
308 return Err(LinalgError::InvalidInputError(
309 "Cannot compute real logarithm of non-positive scalar".to_string(),
310 ));
311 }
312
313 let mut result = Array2::<F>::zeros((1, 1));
314 result[[0, 0]] = val.ln();
315 return Ok(result);
316 }
317
318 let mut is_diagonal = true;
320 for i in 0..n {
321 for j in 0..n {
322 if i != j && a[[i, j]].abs() > F::epsilon() {
323 is_diagonal = false;
324 break;
325 }
326 }
327 if !is_diagonal {
328 break;
329 }
330 }
331
332 if is_diagonal {
333 for i in 0..n {
335 if a[[i, i]] <= F::zero() {
336 return Err(LinalgError::InvalidInputError(
337 "Cannot compute real logarithm of matrix with non-positive eigenvalues"
338 .to_string(),
339 ));
340 }
341 }
342
343 let mut result = Array2::<F>::zeros((n, n));
344 for i in 0..n {
345 result[[i, i]] = a[[i, i]].ln();
346 }
347 return Ok(result);
348 }
349
350 let mut is_identity = true;
352 for i in 0..n {
353 for j in 0..n {
354 let expected = if i == j { F::one() } else { F::zero() };
355 if (a[[i, j]] - expected).abs() > F::epsilon() {
356 is_identity = false;
357 break;
358 }
359 }
360 if !is_identity {
361 break;
362 }
363 }
364
365 if is_identity {
367 return Ok(Array2::<F>::zeros((n, n)));
368 }
369
370 if n == 2 && a[[0, 1]].abs() < F::epsilon() && a[[1, 0]].abs() < F::epsilon() {
372 let a00 = a[[0, 0]];
373 let a11 = a[[1, 1]];
374
375 if a00 <= F::zero() || a11 <= F::zero() {
376 return Err(LinalgError::InvalidInputError(
377 "Cannot compute real logarithm of matrix with non-positive eigenvalues".to_string(),
378 ));
379 }
380
381 let mut result = Array2::<F>::zeros((2, 2));
382 result[[0, 0]] = a00.ln();
383 result[[1, 1]] = a11.ln();
384 return Ok(result);
385 }
386
387 let identity = Array2::eye(n);
393 let mut max_diff = F::zero();
394 for i in 0..n {
395 for j in 0..n {
396 let diff = (a[[i, j]] - identity[[i, j]]).abs();
397 if diff > max_diff {
398 max_diff = diff;
399 }
400 }
401 }
402
403 if max_diff > F::from(0.5).unwrap() {
405 let mut scaling_k = 0;
411 let mut a_scaled = a.to_owned();
412
413 while scaling_k < 10 {
416 let mut max_scaled_diff = F::zero();
418 for i in 0..n {
419 for j in 0..n {
420 let expected = if i == j { F::one() } else { F::zero() };
421 let diff = (a_scaled[[i, j]] - expected).abs();
422 if diff > max_scaled_diff {
423 max_scaled_diff = diff;
424 }
425 }
426 }
427
428 if max_scaled_diff <= F::from(0.2).unwrap() {
429 break;
430 }
431
432 match sqrtm(&a_scaled.view(), 20, F::from(1e-12).unwrap()) {
434 Ok(sqrt_result) => {
435 a_scaled = sqrt_result;
436 scaling_k += 1;
437 }
438 Err(_) => {
439 return Err(LinalgError::ImplementationError(
440 "Matrix logarithm: Could not compute matrix square root for scaling"
441 .to_string(),
442 ));
443 }
444 }
445 }
446
447 if scaling_k >= 10 {
448 return Err(LinalgError::ImplementationError(
449 "Matrix logarithm: Matrix could not be scaled close enough to identity".to_string(),
450 ));
451 }
452
453 let mut x_scaled = Array2::<F>::zeros((n, n));
455 for i in 0..n {
456 for j in 0..n {
457 let expected = if i == j { F::one() } else { F::zero() };
458 x_scaled[[i, j]] = a_scaled[[i, j]] - expected;
459 }
460 }
461
462 let mut x2 = Array2::<F>::zeros((n, n));
464 for i in 0..n {
465 for j in 0..n {
466 for k in 0..n {
467 x2[[i, j]] += x_scaled[[i, k]] * x_scaled[[k, j]];
468 }
469 }
470 }
471
472 let mut x3 = Array2::<F>::zeros((n, n));
473 for i in 0..n {
474 for j in 0..n {
475 for k in 0..n {
476 x3[[i, j]] += x2[[i, k]] * x_scaled[[k, j]];
477 }
478 }
479 }
480
481 let mut x4 = Array2::<F>::zeros((n, n));
482 for i in 0..n {
483 for j in 0..n {
484 for k in 0..n {
485 x4[[i, j]] += x3[[i, k]] * x_scaled[[k, j]];
486 }
487 }
488 }
489
490 let mut x5 = Array2::<F>::zeros((n, n));
491 for i in 0..n {
492 for j in 0..n {
493 for k in 0..n {
494 x5[[i, j]] += x4[[i, k]] * x_scaled[[k, j]];
495 }
496 }
497 }
498
499 let mut x6 = Array2::<F>::zeros((n, n));
500 for i in 0..n {
501 for j in 0..n {
502 for k in 0..n {
503 x6[[i, j]] += x5[[i, k]] * x_scaled[[k, j]];
504 }
505 }
506 }
507
508 let mut log_scaled = Array2::<F>::zeros((n, n));
511 let half = F::from(0.5).unwrap();
512 let third = F::from(1.0 / 3.0).unwrap();
513 let fourth = F::from(0.25).unwrap();
514 let fifth = F::from(0.2).unwrap();
515 let sixth = F::from(1.0 / 6.0).unwrap();
516
517 for i in 0..n {
518 for j in 0..n {
519 log_scaled[[i, j]] = x_scaled[[i, j]] - half * x2[[i, j]] + third * x3[[i, j]]
520 - fourth * x4[[i, j]]
521 + fifth * x5[[i, j]]
522 - sixth * x6[[i, j]];
523 }
524 }
525
526 let scale_factor = F::from(2.0_f64.powi(scaling_k)).unwrap();
528 for i in 0..n {
529 for j in 0..n {
530 log_scaled[[i, j]] *= scale_factor;
531 }
532 }
533
534 return Ok(log_scaled);
535 }
536
537 let mut x = Array2::<F>::zeros((n, n));
540 for i in 0..n {
541 for j in 0..n {
542 x[[i, j]] = a[[i, j]] - identity[[i, j]];
543 }
544 }
545
546 let mut x2 = Array2::<F>::zeros((n, n));
548 for i in 0..n {
549 for j in 0..n {
550 for k in 0..n {
551 x2[[i, j]] += x[[i, k]] * x[[k, j]];
552 }
553 }
554 }
555
556 let mut x3 = Array2::<F>::zeros((n, n));
557 for i in 0..n {
558 for j in 0..n {
559 for k in 0..n {
560 x3[[i, j]] += x2[[i, k]] * x[[k, j]];
561 }
562 }
563 }
564
565 let mut x4 = Array2::<F>::zeros((n, n));
566 for i in 0..n {
567 for j in 0..n {
568 for k in 0..n {
569 x4[[i, j]] += x3[[i, k]] * x[[k, j]];
570 }
571 }
572 }
573
574 let mut x5 = Array2::<F>::zeros((n, n));
575 for i in 0..n {
576 for j in 0..n {
577 for k in 0..n {
578 x5[[i, j]] += x4[[i, k]] * x[[k, j]];
579 }
580 }
581 }
582
583 let mut x6 = Array2::<F>::zeros((n, n));
584 for i in 0..n {
585 for j in 0..n {
586 for k in 0..n {
587 x6[[i, j]] += x5[[i, k]] * x[[k, j]];
588 }
589 }
590 }
591
592 let mut result = Array2::<F>::zeros((n, n));
594 let half = F::from(0.5).unwrap();
595 let third = F::from(1.0 / 3.0).unwrap();
596 let fourth = F::from(0.25).unwrap();
597 let fifth = F::from(0.2).unwrap();
598 let sixth = F::from(1.0 / 6.0).unwrap();
599
600 for i in 0..n {
601 for j in 0..n {
602 result[[i, j]] = x[[i, j]] - half * x2[[i, j]] + third * x3[[i, j]]
603 - fourth * x4[[i, j]]
604 + fifth * x5[[i, j]]
605 - sixth * x6[[i, j]];
606 }
607 }
608
609 Ok(result)
610}
611
612#[allow(dead_code)]
639pub fn logm_parallel<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
640where
641 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
642{
643 use crate::parallel;
644
645 parallel::configure_workers(workers);
647
648 const PARALLEL_THRESHOLD: usize = 50; if a.nrows() < PARALLEL_THRESHOLD || a.ncols() < PARALLEL_THRESHOLD {
652 return logm(a);
654 }
655
656 logm_impl_parallel(a)
658}
659
660#[allow(dead_code)]
662fn logm_impl_parallel<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
663where
664 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
665{
666 logm_impl(a)
669}
670
671#[allow(dead_code)]
702pub fn sqrtm<F>(a: &ArrayView2<F>, maxiter: usize, tol: F) -> LinalgResult<Array2<F>>
703where
704 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
705{
706 sqrtm_impl(a, maxiter, tol)
707}
708
709#[allow(dead_code)]
711fn sqrtm_impl<F>(a: &ArrayView2<F>, maxiter: usize, tol: F) -> LinalgResult<Array2<F>>
712where
713 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
714{
715 validate_decomposition(a, "Matrix square root computation", true)?;
716
717 let n = a.nrows();
718
719 if n == 1 {
721 let val = a[[0, 0]];
722 if val < F::zero() {
723 return Err(LinalgError::InvalidInputError(
724 "Cannot compute real square root of negative number".to_string(),
725 ));
726 }
727 let mut result = Array2::<F>::zeros((1, 1));
728 result[[0, 0]] = val.sqrt();
729 return Ok(result);
730 }
731
732 let mut is_diagonal = true;
734 for i in 0..n {
735 for j in 0..n {
736 if i != j && a[[i, j]].abs() > F::epsilon() {
737 is_diagonal = false;
738 break;
739 }
740 }
741 if !is_diagonal {
742 break;
743 }
744 }
745
746 if is_diagonal {
747 let mut result = Array2::<F>::zeros((n, n));
748 for i in 0..n {
749 if a[[i, i]] < F::zero() {
750 return Err(LinalgError::InvalidInputError(
751 "Cannot compute real square root of matrix with negative eigenvalues"
752 .to_string(),
753 ));
754 }
755 result[[i, i]] = a[[i, i]].sqrt();
756 }
757 return Ok(result);
758 }
759
760 let mut x = a.to_owned();
762 let mut y = Array2::eye(n);
763
764 for _ in 0..maxiter {
765 let x_prev = x.clone();
767
768 let y_inv = solve_multiple(&y.view(), &Array2::eye(n).view(), None)?;
774 let x_inv = solve_multiple(&x.view(), &Array2::eye(n).view(), None)?;
775
776 let mut x_new = Array2::<F>::zeros((n, n));
778 let mut y_new = Array2::<F>::zeros((n, n));
779
780 for i in 0..n {
781 for j in 0..n {
782 x_new[[i, j]] = (x[[i, j]] + y_inv[[i, j]]) * F::from(0.5).unwrap();
783 y_new[[i, j]] = (y[[i, j]] + x_inv[[i, j]]) * F::from(0.5).unwrap();
784 }
785 }
786
787 x = x_new;
788 y = y_new;
789
790 let mut max_diff = F::zero();
792 for i in 0..n {
793 for j in 0..n {
794 let diff = (x[[i, j]] - x_prev[[i, j]]).abs();
795 if diff > max_diff {
796 max_diff = diff;
797 }
798 }
799 }
800
801 if max_diff < tol {
802 break;
803 }
804 }
805
806 Ok(x)
807}
808
809#[allow(dead_code)]
822pub fn sqrtm_parallel<F>(
823 a: &ArrayView2<F>,
824 maxiter: usize,
825 tol: F,
826 workers: Option<usize>,
827) -> LinalgResult<Array2<F>>
828where
829 F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
830{
831 use crate::parallel;
832
833 parallel::configure_workers(workers);
835
836 const PARALLEL_THRESHOLD: usize = 50;
838 if a.nrows() < PARALLEL_THRESHOLD {
839 return sqrtm(a, maxiter, tol);
840 }
841
842 sqrtm_impl(a, maxiter, tol)
845}
846
847#[allow(dead_code)]
869pub fn matrix_power<F>(a: &ArrayView2<F>, p: F) -> LinalgResult<Array2<F>>
870where
871 F: Float + NumAssign + Sum + One + 'static + Send + Sync + scirs2_core::ndarray::ScalarOperand,
872{
873 validate_decomposition(a, "Matrix power computation", true)?;
874
875 let n = a.nrows();
876
877 if p.abs() < F::epsilon() {
879 return Ok(Array2::eye(n));
880 }
881
882 if (p - F::one()).abs() < F::epsilon() {
884 return Ok(a.to_owned());
885 }
886
887 let mut is_diagonal = true;
889 for i in 0..n {
890 for j in 0..n {
891 if i != j && a[[i, j]].abs() > F::epsilon() {
892 is_diagonal = false;
893 break;
894 }
895 }
896 if !is_diagonal {
897 break;
898 }
899 }
900
901 if is_diagonal {
902 let mut result = Array2::<F>::zeros((n, n));
903 for i in 0..n {
904 let val = a[[i, i]];
905 if val < F::zero() && !is_integer(p) {
906 return Err(LinalgError::InvalidInputError(
907 "Cannot compute real fractional power of negative number".to_string(),
908 ));
909 }
910 result[[i, i]] = val.powf(p);
911 }
912 return Ok(result);
913 }
914
915 if is_integer(p) {
917 let int_p = p.to_i32().unwrap_or(0);
918 if int_p >= 0 {
919 let mut result = Array2::eye(n);
921 let mut base = a.to_owned();
922 let mut exp = int_p as u32;
923
924 while exp > 0 {
925 if exp % 2 == 1 {
926 let mut temp = Array2::<F>::zeros((n, n));
928 for i in 0..n {
929 for j in 0..n {
930 for k in 0..n {
931 temp[[i, j]] += result[[i, k]] * base[[k, j]];
932 }
933 }
934 }
935 result = temp;
936 }
937 let mut temp = Array2::<F>::zeros((n, n));
939 for i in 0..n {
940 for j in 0..n {
941 for k in 0..n {
942 temp[[i, j]] += base[[i, k]] * base[[k, j]];
943 }
944 }
945 }
946 base = temp;
947 exp /= 2;
948 }
949 return Ok(result);
950 }
951 }
952
953 Err(LinalgError::ImplementationError(
956 "Matrix power for non-integer powers on general matrices is not yet fully implemented"
957 .to_string(),
958 ))
959}
960
961fn is_integer<F: Float>(x: F) -> bool {
963 (x - x.round()).abs() < F::from(1e-10).unwrap_or(F::epsilon())
964}