scirs2_linalg/matrix_functions/
exponential.rs

//! Matrix exponential, logarithm, square root, and power functions

use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, One};
use std::iter::Sum;

use crate::error::{LinalgError, LinalgResult};
use crate::norm::matrix_norm;
use crate::solve::solve_multiple;
use crate::validation::validate_decomposition;

/// Compute the matrix exponential using Padé approximation.
///
/// The matrix exponential is defined as the power series:
/// exp(A) = I + A + A²/2! + A³/3! + ...
///
/// This function uses the Padé approximation method with scaling and squaring,
/// which is numerically stable and efficient for most matrices.
///
/// # Arguments
///
/// * `a` - Input square matrix
/// * `workers` - Number of worker threads (None = use default)
///
/// # Returns
///
/// * Matrix exponential of a
///
/// # Examples
///
/// ```no_run
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::expm;
///
/// let a = array![[0.0_f64, 1.0], [-1.0, 0.0]]; // Rotation generator
/// let exp_a = expm(&a.view(), None).unwrap();
///
/// // The exact result is the rotation matrix:
/// // [[cos(1), sin(1)], [-sin(1), cos(1)]]
/// ```
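///
/// For a diagonal matrix the exponential simply exponentiates the diagonal
/// entries, which gives a quick sanity check (illustrative example):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::expm;
///
/// let d = array![[1.0_f64, 0.0], [0.0, 2.0]];
/// let exp_d = expm(&d.view(), None).unwrap();
/// assert!((exp_d[[0, 0]] - 1.0_f64.exp()).abs() < 1e-8);
/// assert!((exp_d[[1, 1]] - 2.0_f64.exp()).abs() < 1e-8);
/// ```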
#[allow(dead_code)]
pub fn expm<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    use crate::parallel;

    // Configure workers for parallel operations
    parallel::configure_workers(workers);

    // Parameter validation using validation helpers
    validate_decomposition(a, "Matrix exponential computation", true)?;

    let n = a.nrows();

    // Special case for 1x1 matrix
    if n == 1 {
        let mut result = Array2::<F>::zeros((1, 1));
        result[[0, 0]] = a[[0, 0]].exp();
        return Ok(result);
    }

    // Special case for diagonal matrix
    let mut is_diagonal = true;
    for i in 0..n {
        for j in 0..n {
            if i != j && a[[i, j]].abs() > F::epsilon() {
                is_diagonal = false;
                break;
            }
        }
        if !is_diagonal {
            break;
        }
    }

    if is_diagonal {
        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            result[[i, i]] = a[[i, i]].exp();
        }
        return Ok(result);
    }

    // Choose a suitable scaling factor and Padé order
    let norm_a = matrix_norm(a, "1", None)?;
    let scaling_f = norm_a.log2().ceil().max(F::zero());
    let scaling = scaling_f.to_i32().unwrap_or(0);
    let s = F::from(2.0_f64.powi(-scaling)).unwrap_or(F::one());

    // Scale the matrix
    let mut a_scaled = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            a_scaled[[i, j]] = a[[i, j]] * s;
        }
    }

    // Coefficients of the [5/5] diagonal Padé approximant to exp(x):
    // numerator p(A) = sum_j c[j] * A^j, denominator q(A) = p(-A).
    let c = [
        F::from(1.0).unwrap(),
        F::from(1.0 / 2.0).unwrap(),
        F::from(1.0 / 9.0).unwrap(),
        F::from(1.0 / 72.0).unwrap(),
        F::from(1.0 / 1008.0).unwrap(),
        F::from(1.0 / 30240.0).unwrap(),
    ];
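    // (For reference: the [m/m] Padé coefficients are c_j = (2m - j)! m! / ((2m)! j! (m - j)!);
    //  with m = 5 this gives 1, 1/2, 1/9, 1/72, 1/1008, 1/30240.)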

    // Compute powers of A
    let mut a2 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                a2[[i, j]] += a_scaled[[i, k]] * a_scaled[[k, j]];
            }
        }
    }

    let mut a4 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                a4[[i, j]] += a2[[i, k]] * a2[[k, j]];
            }
        }
    }

    // Compute the numerator of the Padé approximant: N = I + c_1*A + c_2*A^2 + ...
    let mut n_pade = Array2::<F>::zeros((n, n));
    for i in 0..n {
        n_pade[[i, i]] = c[0]; // Add identity matrix * c[0]
    }

    // Add c[1] * A
    for i in 0..n {
        for j in 0..n {
            n_pade[[i, j]] += c[1] * a_scaled[[i, j]];
        }
    }

    // Add c[2] * A^2
    for i in 0..n {
        for j in 0..n {
            n_pade[[i, j]] += c[2] * a2[[i, j]];
        }
    }

    // Add c[3] * A^3
    let mut a3 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                a3[[i, j]] += a_scaled[[i, k]] * a2[[k, j]];
            }
        }
    }

    for i in 0..n {
        for j in 0..n {
            n_pade[[i, j]] += c[3] * a3[[i, j]];
        }
    }

    // Add c[4] * A^4
    for i in 0..n {
        for j in 0..n {
            n_pade[[i, j]] += c[4] * a4[[i, j]];
        }
    }

    // Add c[5] * A^5
    let mut a5 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                a5[[i, j]] += a_scaled[[i, k]] * a4[[k, j]];
            }
        }
    }

    for i in 0..n {
        for j in 0..n {
            n_pade[[i, j]] += c[5] * a5[[i, j]];
        }
    }

    // Compute the denominator of the Padé approximant: D = I - c_1*A + c_2*A^2 - ...
    let mut d_pade = Array2::<F>::zeros((n, n));
    for i in 0..n {
        d_pade[[i, i]] = c[0]; // Add identity matrix * c[0]
    }

    // Subtract c[1] * A
    for i in 0..n {
        for j in 0..n {
            d_pade[[i, j]] -= c[1] * a_scaled[[i, j]];
        }
    }

    // Add c[2] * A^2
    for i in 0..n {
        for j in 0..n {
            d_pade[[i, j]] += c[2] * a2[[i, j]];
        }
    }

    // Subtract c[3] * A^3
    for i in 0..n {
        for j in 0..n {
            d_pade[[i, j]] -= c[3] * a3[[i, j]];
        }
    }

    // Add c[4] * A^4
    for i in 0..n {
        for j in 0..n {
            d_pade[[i, j]] += c[4] * a4[[i, j]];
        }
    }

    // Subtract c[5] * A^5
    for i in 0..n {
        for j in 0..n {
            d_pade[[i, j]] -= c[5] * a5[[i, j]];
        }
    }

    // Solve the system D*X = N for X
    let result = solve_multiple(&d_pade.view(), &n_pade.view(), None)?;

    // Undo the scaling by squaring the result s times
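    // (exp(A) = (exp(2^{-s} A))^{2^s}, so s repeated squarings restore the original scale.)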
    let mut exp_a = result;

    for _ in 0..scaling as usize {
        let mut temp = Array2::<F>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    temp[[i, j]] += exp_a[[i, k]] * exp_a[[k, j]];
                }
            }
        }
        exp_a = temp;
    }

    Ok(exp_a)
}

/// Compute the matrix logarithm.
///
/// The matrix logarithm is the inverse of the matrix exponential:
/// if expm(B) = A, then logm(A) = B.
///
/// This implementation uses inverse scaling and squaring: repeated matrix square
/// roots bring the input close to the identity, a truncated series for
/// log(I + X) is applied there, and the result is rescaled.
///
/// # Arguments
///
/// * `a` - Input square matrix (must have eigenvalues with positive real parts for a real result)
///
/// # Returns
///
/// * Matrix logarithm of a
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::logm;
///
/// let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
/// let log_a = logm(&a.view()).unwrap();
/// // log_a should be approximately [[0.0, 0.0], [0.0, ln(2)]]
/// assert!((log_a[[0, 0]]).abs() < 1e-10);
/// assert!((log_a[[0, 1]]).abs() < 1e-10);
/// assert!((log_a[[1, 0]]).abs() < 1e-10);
/// assert!((log_a[[1, 1]] - 2.0_f64.ln()).abs() < 1e-10);
/// ```
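///
/// # Notes
///
/// After k square roots the matrix B = A^(1/2^k) is close to the identity, and
/// log(B) is evaluated with the truncated series
/// log(I + X) = X - X²/2 + X³/3 - X⁴/4 + X⁵/5 - X⁶/6 with X = B - I,
/// then rescaled via log(A) = 2^k · log(B). Accuracy is therefore limited for
/// matrices whose eigenvalues are far from 1.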
#[allow(dead_code)]
pub fn logm<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    logm_impl(a)
}

/// Internal implementation of matrix logarithm computation.
#[allow(dead_code)]
fn logm_impl<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Matrix must be square to compute logarithm, got shape {:?}",
            a.shape()
        )));
    }

    let n = a.nrows();

    // Special case for 1x1 matrix
    if n == 1 {
        let val = a[[0, 0]];
        if val <= F::zero() {
            return Err(LinalgError::InvalidInputError(
                "Cannot compute real logarithm of non-positive scalar".to_string(),
            ));
        }

        let mut result = Array2::<F>::zeros((1, 1));
        result[[0, 0]] = val.ln();
        return Ok(result);
    }

    // Special case for diagonal matrix
    let mut is_diagonal = true;
    for i in 0..n {
        for j in 0..n {
            if i != j && a[[i, j]].abs() > F::epsilon() {
                is_diagonal = false;
                break;
            }
        }
        if !is_diagonal {
            break;
        }
    }

    if is_diagonal {
        // Check that all diagonal elements are positive
        for i in 0..n {
            if a[[i, i]] <= F::zero() {
                return Err(LinalgError::InvalidInputError(
                    "Cannot compute real logarithm of matrix with non-positive eigenvalues"
                        .to_string(),
                ));
            }
        }

        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            result[[i, i]] = a[[i, i]].ln();
        }
        return Ok(result);
    }

    // Check if the matrix is the identity
    let mut is_identity = true;
    for i in 0..n {
        for j in 0..n {
            let expected = if i == j { F::one() } else { F::zero() };
            if (a[[i, j]] - expected).abs() > F::epsilon() {
                is_identity = false;
                break;
            }
        }
        if !is_identity {
            break;
        }
    }

    // log(I) = 0
    if is_identity {
        return Ok(Array2::<F>::zeros((n, n)));
    }

    // Special case for 2x2 diagonal matrix
    if n == 2 && a[[0, 1]].abs() < F::epsilon() && a[[1, 0]].abs() < F::epsilon() {
        let a00 = a[[0, 0]];
        let a11 = a[[1, 1]];

        if a00 <= F::zero() || a11 <= F::zero() {
            return Err(LinalgError::InvalidInputError(
                "Cannot compute real logarithm of matrix with non-positive eigenvalues".to_string(),
            ));
        }

        let mut result = Array2::<F>::zeros((2, 2));
        result[[0, 0]] = a00.ln();
        result[[1, 1]] = a11.ln();
        return Ok(result);
    }

    // For general matrices, we use a simplified approach for matrices close to the identity
    // This is a basic implementation that works for many cases but is not as robust as
    // a full Schur decomposition-based implementation

    // Check if the matrix is close to the identity (within a reasonable range)
    let identity = Array2::eye(n);
    let mut max_diff = F::zero();
    for i in 0..n {
        for j in 0..n {
            let diff = (a[[i, j]] - identity[[i, j]]).abs();
            if diff > max_diff {
                max_diff = diff;
            }
        }
    }

    // If the matrix is too far from identity, try an inverse scaling and squaring approach
    if max_diff > F::from(0.5).unwrap() {
        // For matrices not close to identity, we use inverse scaling and squaring
        // This approach works by finding a scaling factor k such that A^(1/2^k) is close to I
        // then computing log(A) = 2^k * log(A^(1/2^k))

        // Find an appropriate scaling factor
        let mut scaling_k = 0;
        let mut a_scaled = a.to_owned();

        // Try to find a scaling where the matrix becomes closer to identity
        // We'll use matrix square root iterations to get A^(1/2^k)
        while scaling_k < 10 {
            // Limit iterations to avoid infinite loops
            let mut max_scaled_diff = F::zero();
            for i in 0..n {
                for j in 0..n {
                    let expected = if i == j { F::one() } else { F::zero() };
                    let diff = (a_scaled[[i, j]] - expected).abs();
                    if diff > max_scaled_diff {
                        max_scaled_diff = diff;
                    }
                }
            }

            if max_scaled_diff <= F::from(0.2).unwrap() {
                break;
            }

            // Compute matrix square root using our sqrtm function
            match sqrtm(&a_scaled.view(), 20, F::from(1e-12).unwrap()) {
                Ok(sqrt_result) => {
                    a_scaled = sqrt_result;
                    scaling_k += 1;
                }
                Err(_) => {
                    return Err(LinalgError::ImplementationError(
                        "Matrix logarithm: Could not compute matrix square root for scaling"
                            .to_string(),
                    ));
                }
            }
        }

        if scaling_k >= 10 {
            return Err(LinalgError::ImplementationError(
                "Matrix logarithm: Matrix could not be scaled close enough to identity".to_string(),
            ));
        }

        // Now compute log(A^(1/2^k)) using the series
        let mut x_scaled = Array2::<F>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { F::one() } else { F::zero() };
                x_scaled[[i, j]] = a_scaled[[i, j]] - expected;
            }
        }

        // Compute powers of X for the series (use more terms for better accuracy)
        let mut x2 = Array2::<F>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    x2[[i, j]] += x_scaled[[i, k]] * x_scaled[[k, j]];
                }
            }
        }

        let mut x3 = Array2::<F>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    x3[[i, j]] += x2[[i, k]] * x_scaled[[k, j]];
                }
            }
        }

        let mut x4 = Array2::<F>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    x4[[i, j]] += x3[[i, k]] * x_scaled[[k, j]];
                }
            }
        }

        let mut x5 = Array2::<F>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    x5[[i, j]] += x4[[i, k]] * x_scaled[[k, j]];
                }
            }
        }

        let mut x6 = Array2::<F>::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    x6[[i, j]] += x5[[i, k]] * x_scaled[[k, j]];
                }
            }
        }

        // Compute log(A^(1/2^k)) using the series with more terms
        // log(1 + X) = X - X²/2 + X³/3 - X⁴/4 + X⁵/5 - X⁶/6 + ...
        let mut log_scaled = Array2::<F>::zeros((n, n));
        let half = F::from(0.5).unwrap();
        let third = F::from(1.0 / 3.0).unwrap();
        let fourth = F::from(0.25).unwrap();
        let fifth = F::from(0.2).unwrap();
        let sixth = F::from(1.0 / 6.0).unwrap();

        for i in 0..n {
            for j in 0..n {
                log_scaled[[i, j]] = x_scaled[[i, j]] - half * x2[[i, j]] + third * x3[[i, j]]
                    - fourth * x4[[i, j]]
                    + fifth * x5[[i, j]]
                    - sixth * x6[[i, j]];
            }
        }

        // Scale back: log(A) = 2^k * log(A^(1/2^k))
        let scale_factor = F::from(2.0_f64.powi(scaling_k)).unwrap();
        for i in 0..n {
            for j in 0..n {
                log_scaled[[i, j]] *= scale_factor;
            }
        }

        return Ok(log_scaled);
    }

    // For matrices close to I, we can use the series: log(I + X) = X - X²/2 + X³/3 - X⁴/4 + ...
    // where X = A - I
    let mut x = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            x[[i, j]] = a[[i, j]] - identity[[i, j]];
        }
    }

    // Compute X^2, X^3, X^4, X^5, X^6 for the series
    let mut x2 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                x2[[i, j]] += x[[i, k]] * x[[k, j]];
            }
        }
    }

    let mut x3 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                x3[[i, j]] += x2[[i, k]] * x[[k, j]];
            }
        }
    }

    let mut x4 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                x4[[i, j]] += x3[[i, k]] * x[[k, j]];
            }
        }
    }

    let mut x5 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                x5[[i, j]] += x4[[i, k]] * x[[k, j]];
            }
        }
    }

    let mut x6 = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                x6[[i, j]] += x5[[i, k]] * x[[k, j]];
            }
        }
    }

    // Compute log(A) using the series log(I + X) = X - X²/2 + X³/3 - X⁴/4 + X⁵/5 - X⁶/6 + ...
    let mut result = Array2::<F>::zeros((n, n));
    let half = F::from(0.5).unwrap();
    let third = F::from(1.0 / 3.0).unwrap();
    let fourth = F::from(0.25).unwrap();
    let fifth = F::from(0.2).unwrap();
    let sixth = F::from(1.0 / 6.0).unwrap();

    for i in 0..n {
        for j in 0..n {
            result[[i, j]] = x[[i, j]] - half * x2[[i, j]] + third * x3[[i, j]]
                - fourth * x4[[i, j]]
                + fifth * x5[[i, j]]
                - sixth * x6[[i, j]];
        }
    }

    Ok(result)
}


/// Compute the matrix logarithm with parallel processing support.
///
/// This function computes log(A) for a square matrix A with the same inverse
/// scaling and squaring algorithm as [`logm`]. For matrices at or above an
/// internal size threshold it is intended to use parallel matrix operations;
/// the current implementation falls back to the sequential algorithm.
///
/// # Arguments
///
/// * `a` - Input square matrix
/// * `workers` - Number of worker threads (None = use default)
///
/// # Returns
///
/// * Matrix logarithm of the input
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::logm_parallel;
///
/// let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
/// let log_a = logm_parallel(&a.view(), Some(4)).unwrap();
/// assert!((log_a[[0, 0]]).abs() < 1e-10);
/// assert!((log_a[[1, 1]] - 2.0_f64.ln()).abs() < 1e-10);
/// ```
#[allow(dead_code)]
pub fn logm_parallel<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    use crate::parallel;

    // Configure workers for parallel operations
    parallel::configure_workers(workers);

    // Use threshold to determine if parallel processing is worthwhile
    const PARALLEL_THRESHOLD: usize = 50; // For matrices larger than 50x50

    if a.nrows() < PARALLEL_THRESHOLD || a.ncols() < PARALLEL_THRESHOLD {
        // For small matrices, use sequential implementation
        return logm(a);
    }

    // For larger matrices, use the same algorithm but with parallel matrix operations
    logm_impl_parallel(a)
}

/// Internal implementation of parallel matrix logarithm computation.
#[allow(dead_code)]
fn logm_impl_parallel<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    // For now, use the sequential implementation
    // TODO: Implement parallel version using scirs2_core::parallel_ops
    logm_impl(a)
}

/// Compute the matrix square root using the Denman-Beavers iteration.
///
/// The matrix square root X of matrix A satisfies X^2 = A.
/// This function uses the Denman-Beavers iteration, which is suitable
/// for matrices with no eigenvalues on the negative real axis.
///
/// # Arguments
///
/// * `a` - Input square matrix (should be positive definite for a real result)
/// * `maxiter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
///
/// # Returns
///
/// * Matrix square root of a
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::sqrtm;
///
/// let a = array![[4.0_f64, 0.0], [0.0, 9.0]];
/// let sqrt_a = sqrtm(&a.view(), 20, 1e-10).unwrap();
/// // sqrt_a should be approximately [[2.0, 0.0], [0.0, 3.0]]
/// assert!((sqrt_a[[0, 0]] - 2.0).abs() < 1e-10);
/// assert!((sqrt_a[[0, 1]] - 0.0).abs() < 1e-10);
/// assert!((sqrt_a[[1, 0]] - 0.0).abs() < 1e-10);
/// assert!((sqrt_a[[1, 1]] - 3.0).abs() < 1e-10);
/// ```
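///
/// The iteration also handles non-diagonal symmetric positive definite matrices;
/// a quick way to check the result is to square it (illustrative example):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::sqrtm;
///
/// let b = array![[2.0_f64, 1.0], [1.0, 2.0]];
/// let sqrt_b = sqrtm(&b.view(), 50, 1e-12).unwrap();
/// let check = sqrt_b.dot(&sqrt_b);
/// assert!((check[[0, 0]] - 2.0).abs() < 1e-6);
/// assert!((check[[0, 1]] - 1.0).abs() < 1e-6);
/// ```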
#[allow(dead_code)]
pub fn sqrtm<F>(a: &ArrayView2<F>, maxiter: usize, tol: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    sqrtm_impl(a, maxiter, tol)
}

/// Internal implementation of matrix square root computation.
#[allow(dead_code)]
fn sqrtm_impl<F>(a: &ArrayView2<F>, maxiter: usize, tol: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    validate_decomposition(a, "Matrix square root computation", true)?;

    let n = a.nrows();

    // Special case for 1x1 matrix
    if n == 1 {
        let val = a[[0, 0]];
        if val < F::zero() {
            return Err(LinalgError::InvalidInputError(
                "Cannot compute real square root of negative number".to_string(),
            ));
        }
        let mut result = Array2::<F>::zeros((1, 1));
        result[[0, 0]] = val.sqrt();
        return Ok(result);
    }

    // Special case for diagonal matrix
    let mut is_diagonal = true;
    for i in 0..n {
        for j in 0..n {
            if i != j && a[[i, j]].abs() > F::epsilon() {
                is_diagonal = false;
                break;
            }
        }
        if !is_diagonal {
            break;
        }
    }

    if is_diagonal {
        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            if a[[i, i]] < F::zero() {
                return Err(LinalgError::InvalidInputError(
                    "Cannot compute real square root of matrix with negative eigenvalues"
                        .to_string(),
                ));
            }
            result[[i, i]] = a[[i, i]].sqrt();
        }
        return Ok(result);
    }

    // Use Denman-Beavers iteration for general matrices
    let mut x = a.to_owned();
    let mut y = Array2::eye(n);

    for _ in 0..maxiter {
        // Store previous iteration for convergence check
        let x_prev = x.clone();

        // Compute X_{k+1} = (X_k + Y_k^{-1}) / 2
        // and Y_{k+1} = (Y_k + X_k^{-1}) / 2

        // For simplicity, we'll use a basic implementation
        // In practice, you'd want more sophisticated inversion
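        // The iterates converge quadratically, with X_k -> A^(1/2) and Y_k -> A^(-1/2),
        // provided A has no eigenvalues on the negative real axis.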
        let y_inv = solve_multiple(&y.view(), &Array2::eye(n).view(), None)?;
        let x_inv = solve_multiple(&x.view(), &Array2::eye(n).view(), None)?;

        // Update X and Y
        let mut x_new = Array2::<F>::zeros((n, n));
        let mut y_new = Array2::<F>::zeros((n, n));

        for i in 0..n {
            for j in 0..n {
                x_new[[i, j]] = (x[[i, j]] + y_inv[[i, j]]) * F::from(0.5).unwrap();
                y_new[[i, j]] = (y[[i, j]] + x_inv[[i, j]]) * F::from(0.5).unwrap();
            }
        }

        x = x_new;
        y = y_new;

        // Check convergence
        let mut max_diff = F::zero();
        for i in 0..n {
            for j in 0..n {
                let diff = (x[[i, j]] - x_prev[[i, j]]).abs();
                if diff > max_diff {
                    max_diff = diff;
                }
            }
        }

        if max_diff < tol {
            break;
        }
    }

    Ok(x)
}

/// Compute the matrix square root with parallel processing support.
///
/// # Arguments
///
/// * `a` - Input square matrix
/// * `maxiter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
/// * `workers` - Number of worker threads (None = use default)
///
/// # Returns
///
/// * Matrix square root of a
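///
/// # Examples
///
/// Mirrors the [`sqrtm`] example; small inputs fall back to the sequential
/// implementation (illustrative example):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::sqrtm_parallel;
///
/// let a = array![[4.0_f64, 0.0], [0.0, 9.0]];
/// let sqrt_a = sqrtm_parallel(&a.view(), 20, 1e-10, Some(4)).unwrap();
/// assert!((sqrt_a[[0, 0]] - 2.0).abs() < 1e-10);
/// assert!((sqrt_a[[1, 1]] - 3.0).abs() < 1e-10);
/// ```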
#[allow(dead_code)]
pub fn sqrtm_parallel<F>(
    a: &ArrayView2<F>,
    maxiter: usize,
    tol: F,
    workers: Option<usize>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    use crate::parallel;

    // Configure workers for parallel operations
    parallel::configure_workers(workers);

    // For small matrices, use sequential version
    const PARALLEL_THRESHOLD: usize = 50;
    if a.nrows() < PARALLEL_THRESHOLD {
        return sqrtm(a, maxiter, tol);
    }

    // For now, delegate to sequential implementation
    // TODO: Implement parallel version
    sqrtm_impl(a, maxiter, tol)
}

/// Compute the matrix power A^p for a real number p.
///
/// # Arguments
///
/// * `a` - Input square matrix
/// * `p` - Power (real number)
///
/// # Returns
///
/// * Matrix power A^p
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::matrix_power;
///
/// let a = array![[4.0_f64, 0.0], [0.0, 9.0]];
/// let a_half = matrix_power(&a.view(), 0.5).unwrap();
/// // a_half should be approximately [[2.0, 0.0], [0.0, 3.0]]
/// ```
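///
/// Integer exponents on general (non-diagonal) matrices are evaluated by
/// repeated squaring (illustrative example):
///
/// ```
/// use scirs2_core::ndarray::array;
/// use scirs2_linalg::matrix_functions::matrix_power;
///
/// let m = array![[1.0_f64, 1.0], [0.0, 1.0]];
/// let m_cubed = matrix_power(&m.view(), 3.0).unwrap();
/// // [[1, 1], [0, 1]]^3 = [[1, 3], [0, 1]]
/// assert!((m_cubed[[0, 1]] - 3.0).abs() < 1e-12);
/// assert!((m_cubed[[1, 1]] - 1.0).abs() < 1e-12);
/// ```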
#[allow(dead_code)]
pub fn matrix_power<F>(a: &ArrayView2<F>, p: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + 'static + Send + Sync + scirs2_core::ndarray::ScalarOperand,
{
    validate_decomposition(a, "Matrix power computation", true)?;

    let n = a.nrows();

    // Special case for p = 0 (returns identity)
    if p.abs() < F::epsilon() {
        return Ok(Array2::eye(n));
    }

    // Special case for p = 1 (returns the matrix itself)
    if (p - F::one()).abs() < F::epsilon() {
        return Ok(a.to_owned());
    }

    // Special case for diagonal matrix
    let mut is_diagonal = true;
    for i in 0..n {
        for j in 0..n {
            if i != j && a[[i, j]].abs() > F::epsilon() {
                is_diagonal = false;
                break;
            }
        }
        if !is_diagonal {
            break;
        }
    }

    if is_diagonal {
        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            let val = a[[i, i]];
            if val < F::zero() && !is_integer(p) {
                return Err(LinalgError::InvalidInputError(
                    "Cannot compute real fractional power of negative number".to_string(),
                ));
            }
            result[[i, i]] = val.powf(p);
        }
        return Ok(result);
    }

    // Special case for integer powers
    if is_integer(p) {
        let int_p = p.to_i32().unwrap_or(0);
        if int_p >= 0 {
            // Positive integer power - use repeated squaring
            let mut result = Array2::eye(n);
            let mut base = a.to_owned();
            let mut exp = int_p as u32;

            while exp > 0 {
                if exp % 2 == 1 {
                    // Multiply result by base
                    let mut temp = Array2::<F>::zeros((n, n));
                    for i in 0..n {
                        for j in 0..n {
                            for k in 0..n {
                                temp[[i, j]] += result[[i, k]] * base[[k, j]];
                            }
                        }
                    }
                    result = temp;
                }
                // Square the base
                let mut temp = Array2::<F>::zeros((n, n));
                for i in 0..n {
                    for j in 0..n {
                        for k in 0..n {
                            temp[[i, j]] += base[[i, k]] * base[[k, j]];
                        }
                    }
                }
                base = temp;
                exp /= 2;
            }
            return Ok(result);
        }
    }

    // Negative and non-integer powers on general (non-diagonal) matrices would require
    // an eigendecomposition with complex eigenvalue handling, which is not yet implemented.
    Err(LinalgError::ImplementationError(
        "Matrix power for negative or non-integer exponents on general matrices is not yet implemented"
            .to_string(),
    ))
}

/// Helper function to check if a floating point number is close to an integer
fn is_integer<F: Float>(x: F) -> bool {
    (x - x.round()).abs() < F::from(1e-10).unwrap_or(F::epsilon())
}