Skip to main content

scirs2_linalg/
compat.rs

1//! Backward compatibility layer for ndarray-linalg-style trait-based API
2//!
3//! This module provides trait-based extensions to `ArrayBase` types that mirror
4//! the old `ndarray-linalg` API, making it easier to migrate existing code.
5//!
6//! # Example
7//!
8//! ```rust
9//! use scirs2_core::ndarray::array;
10//! use scirs2_linalg::compat::ArrayLinalgExt;
11//!
12//! let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
13//!
14//! // Old ndarray-linalg style (now works via compat layer)
15//! let (u, s, vt) = a.svd(true).expect("valid input");
16//! let inv_a = a.inv().expect("valid input");
17//! ```
18
19// ✅ SciRS2 POLICY: Use scirs2_core for all external dependencies
20use crate::{LinalgError, LinalgResult};
21use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand};
22use scirs2_core::numeric::{Float, NumAssign, Zero};
23use scirs2_core::Complex;
24use std::iter::Sum;
25
/// UPLO parameter for symmetric/Hermitian matrices
///
/// Names which triangle of a symmetric/Hermitian matrix a caller is referring
/// to, mirroring the `UPLO` argument of the `ndarray-linalg` API.
///
/// NOTE(review): the wrappers visible in this file accept a `UPLO` for
/// signature compatibility but do not act on it — confirm before relying on
/// either variant changing a result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UPLO {
    /// Upper triangular
    Upper,
    /// Lower triangular
    Lower,
}
34
/// Trait providing ndarray-linalg-compatible linear algebra operations
///
/// `A` is the element type of the results. `S` is the storage type of the
/// right-hand sides taken by [`solve`](Self::solve) and
/// [`solve_into`](Self::solve_into). Fallible operations return a
/// [`LinalgResult`]; the two norm methods return the value directly.
pub trait ArrayLinalgExt<A, S: scirs2_core::ndarray::RawData> {
    /// Singular Value Decomposition: A = U Σ Vᵀ
    ///
    /// Returns the tuple `(U, Σ, Vᵀ)`, with `Σ` as a 1-D array.
    /// NOTE(review): the exact effect of `compute_uv` is decided by the
    /// implementor — confirm there.
    fn svd(&self, compute_uv: bool) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)>;

    /// Eigenvalues and eigenvectors of a general matrix
    ///
    /// Both results are complex-valued: a 1-D array of eigenvalues and a 2-D
    /// array of eigenvectors.
    #[allow(clippy::type_complexity)]
    fn eig(
        &self,
    ) -> LinalgResult<(
        Array1<scirs2_core::Complex<A>>,
        Array2<scirs2_core::Complex<A>>,
    )>;

    /// Eigenvalues and eigenvectors of a symmetric/Hermitian matrix
    ///
    /// Returns `(eigenvalues, eigenvectors)` with real element type `A`.
    fn eigh(&self, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>;

    /// Eigenvalues only of a symmetric/Hermitian matrix
    fn eigvalsh(&self, uplo: UPLO) -> LinalgResult<Array1<A>>;

    /// Matrix inverse
    fn inv(&self) -> LinalgResult<Array2<A>>;

    /// Solve linear system Ax = b for vector b
    ///
    /// `b` must use the same storage type `S` as the trait parameter.
    fn solve(&self, b: &ArrayBase<S, Ix1>) -> LinalgResult<Array1<A>>;

    /// Solve linear system Ax = B for matrix B
    ///
    /// `b` must use the same storage type `S` as the trait parameter.
    fn solve_into(&self, b: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>;

    /// L2 norm (Euclidean norm)
    fn norm_l2(&self) -> A;

    /// Frobenius norm
    fn norm_fro(&self) -> A;

    /// Determinant
    fn det(&self) -> LinalgResult<A>;

    /// QR decomposition: A = QR
    ///
    /// Returns the pair `(Q, R)`.
    fn qr(&self) -> LinalgResult<(Array2<A>, Array2<A>)>;

    /// LU decomposition: PA = LU
    ///
    /// Returns three matrices.
    /// NOTE(review): presumably ordered `(P, L, U)` — confirm against the
    /// implementor.
    fn lu(&self) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)>;

    /// Cholesky decomposition: A = LLᵀ
    fn cholesky(&self) -> LinalgResult<Array2<A>>;
}
82
83impl<A, S> ArrayLinalgExt<A, S> for ArrayBase<S, Ix2>
84where
85    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
86    S: Data<Elem = A>,
87{
88    fn svd(&self, compute_uv: bool) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)> {
89        crate::svd(&self.view(), compute_uv, None)
90    }
91
92    fn eig(
93        &self,
94    ) -> LinalgResult<(
95        Array1<scirs2_core::Complex<A>>,
96        Array2<scirs2_core::Complex<A>>,
97    )> {
98        crate::eig(&self.view(), None)
99    }
100
101    fn eigh(&self, _uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)> {
102        // scirs2-linalg's eigh doesn't use UPLO parameter currently
103        crate::eigh(&self.view(), None)
104    }
105
106    fn eigvalsh(&self, _uplo: UPLO) -> LinalgResult<Array1<A>> {
107        crate::eigvalsh(&self.view(), None)
108    }
109
110    fn inv(&self) -> LinalgResult<Array2<A>> {
111        crate::inv(&self.view(), None)
112    }
113
114    fn solve(&self, b: &ArrayBase<S, Ix1>) -> LinalgResult<Array1<A>> {
115        crate::solve(&self.view(), &b.view(), None)
116    }
117
118    fn solve_into(&self, b: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>> {
119        crate::solve_multiple(&self.view(), &b.view(), None)
120    }
121
122    fn norm_l2(&self) -> A {
123        // Calculate Frobenius norm for matrices (equivalent to L2 for flattened matrix)
124        self.iter().map(|&x| x * x).sum::<A>().sqrt()
125    }
126
127    fn norm_fro(&self) -> A {
128        self.iter().map(|&x| x * x).sum::<A>().sqrt()
129    }
130
131    fn det(&self) -> LinalgResult<A> {
132        crate::det(&self.view(), None)
133    }
134
135    fn qr(&self) -> LinalgResult<(Array2<A>, Array2<A>)> {
136        crate::qr(&self.view(), None)
137    }
138
139    fn lu(&self) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)> {
140        crate::lu(&self.view(), None)
141    }
142
143    fn cholesky(&self) -> LinalgResult<Array2<A>> {
144        crate::cholesky(&self.view(), None)
145    }
146}
147
/// Trait for solving linear systems (compatibility)
///
/// Declares the shape of an `ndarray-linalg`-style `Solve` trait. The
/// right-hand side has the same type as the receiver.
/// NOTE(review): no implementation of this trait is visible in this part of
/// the file — confirm whether one exists elsewhere.
pub trait Solve<A> {
    /// Output type for the solution
    type Output;

    /// Solve the linear system
    ///
    /// Uses `self` as the coefficient object and `rhs` as the right-hand side.
    fn solve(&self, rhs: &Self) -> LinalgResult<Self::Output>;
}
156
/// Trait for computing singular value decomposition
///
/// The three associated types let an implementor choose the container used
/// for each factor.
pub trait SVD {
    /// Singular values type
    type S;
    /// Left singular vectors type
    type U;
    /// Right singular vectors type (transposed)
    type Vt;

    /// Compute SVD
    ///
    /// Returns the factors in the order `(U, S, Vt)`.
    fn svd(&self, compute_uv: bool) -> LinalgResult<(Self::U, Self::S, Self::Vt)>;
}
169
170impl<A, S> SVD for ArrayBase<S, Ix2>
171where
172    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
173    S: Data<Elem = A>,
174{
175    type S = Array1<A>;
176    type U = Array2<A>;
177    type Vt = Array2<A>;
178
179    fn svd(&self, compute_uv: bool) -> LinalgResult<(Self::U, Self::S, Self::Vt)> {
180        ArrayLinalgExt::svd(self, compute_uv)
181    }
182}
183
/// Trait for computing eigenvalues and eigenvectors
///
/// The associated types let an implementor choose the containers used for
/// the eigenvalues and eigenvectors.
pub trait Eig {
    /// Eigenvalue type
    type EigVal;
    /// Eigenvector type
    type EigVec;

    /// Compute eigenvalues and eigenvectors
    ///
    /// Returns the pair `(eigenvalues, eigenvectors)`.
    fn eig(&self) -> LinalgResult<(Self::EigVal, Self::EigVec)>;
}
194
195impl<A, S> Eig for ArrayBase<S, Ix2>
196where
197    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
198    S: Data<Elem = A>,
199{
200    type EigVal = Array1<scirs2_core::Complex<A>>;
201    type EigVec = Array2<scirs2_core::Complex<A>>;
202
203    fn eig(&self) -> LinalgResult<(Self::EigVal, Self::EigVec)> {
204        ArrayLinalgExt::eig(self)
205    }
206}
207
/// Trait for computing eigenvalues and eigenvectors of Hermitian matrices
///
/// The associated types let an implementor choose the containers used for
/// the eigenvalues and eigenvectors.
pub trait Eigh {
    /// Eigenvalue type
    type EigVal;
    /// Eigenvector type
    type EigVec;

    /// Compute eigenvalues and eigenvectors
    ///
    /// Returns the pair `(eigenvalues, eigenvectors)`. How `uplo` is honoured
    /// is up to the implementor.
    fn eigh(&self, uplo: UPLO) -> LinalgResult<(Self::EigVal, Self::EigVec)>;
}
218
219impl<A, S> Eigh for ArrayBase<S, Ix2>
220where
221    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
222    S: Data<Elem = A>,
223{
224    type EigVal = Array1<A>;
225    type EigVec = Array2<A>;
226
227    fn eigh(&self, uplo: UPLO) -> LinalgResult<(Self::EigVal, Self::EigVec)> {
228        ArrayLinalgExt::eigh(self, uplo)
229    }
230}
231
/// Trait for computing matrix inverse
///
/// The associated type lets an implementor choose the container used for
/// the inverse.
pub trait Inverse {
    /// Output type
    type Output;

    /// Compute matrix inverse
    ///
    /// Failures are reported through the returned `LinalgResult`.
    fn inv(&self) -> LinalgResult<Self::Output>;
}
240
241impl<A, S> Inverse for ArrayBase<S, Ix2>
242where
243    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
244    S: Data<Elem = A>,
245{
246    type Output = Array2<A>;
247
248    fn inv(&self) -> LinalgResult<Self::Output> {
249        ArrayLinalgExt::inv(self)
250    }
251}
252
/// Trait for computing matrix norms
///
/// `A` is the scalar type of the returned norm. None of the methods can
/// fail; each returns the value directly.
pub trait Norm<A> {
    /// Compute matrix norm
    ///
    /// Which norm this "default" denotes is chosen by the implementor.
    fn norm(&self) -> A;

    /// Compute L2 norm
    fn norm_l2(&self) -> A;

    /// Compute Frobenius norm
    fn norm_fro(&self) -> A;
}
264
265impl<A, S> Norm<A> for ArrayBase<S, Ix2>
266where
267    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
268    S: Data<Elem = A>,
269{
270    fn norm(&self) -> A {
271        ArrayLinalgExt::norm_fro(self)
272    }
273
274    fn norm_l2(&self) -> A {
275        ArrayLinalgExt::norm_l2(self)
276    }
277
278    fn norm_fro(&self) -> A {
279        ArrayLinalgExt::norm_fro(self)
280    }
281}
282
283impl<A, S> Norm<A> for ArrayBase<S, Ix1>
284where
285    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
286    S: Data<Elem = A>,
287{
288    fn norm(&self) -> A {
289        self.norm_l2()
290    }
291
292    fn norm_l2(&self) -> A {
293        self.iter().map(|&x| x * x).sum::<A>().sqrt()
294    }
295
296    fn norm_fro(&self) -> A {
297        self.norm_l2()
298    }
299}
300
301// ============================================================================
302// Standalone wrapper functions for scipy.linalg compatibility
303// ============================================================================
304// These functions provide a scipy.linalg-style API by wrapping the internal
305// implementation functions. They are designed to be used via scipy_compat module.
306
/// Type alias for SVD result (U, S, Vt)
///
/// `U` and `Vt` are 2-D arrays; `S` is the 1-D array of singular values.
pub type SvdResult<A> = (Array2<A>, Array1<A>, Array2<A>);
309
310/// Singular Value Decomposition: A = U Σ Vᵀ
311pub fn svd<A, S>(a: &ArrayBase<S, Ix2>, compute_uv: bool) -> LinalgResult<SvdResult<A>>
312where
313    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
314    S: Data<Elem = A>,
315{
316    crate::svd(&a.view(), compute_uv, None)
317}
318
319/// Compute eigenvalues and eigenvectors of a general matrix
320#[allow(clippy::type_complexity)]
321pub fn eig<A, S>(
322    a: &ArrayBase<S, Ix2>,
323) -> LinalgResult<(
324    Array1<scirs2_core::Complex<A>>,
325    Array2<scirs2_core::Complex<A>>,
326)>
327where
328    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
329    S: Data<Elem = A>,
330{
331    crate::eig(&a.view(), None)
332}
333
334/// Compute eigenvalues and eigenvectors of a symmetric/Hermitian matrix
335pub fn eigh<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>
336where
337    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
338    S: Data<Elem = A>,
339{
340    let _ = uplo; // Currently unused
341    crate::eigh(&a.view(), None)
342}
343
344/// Compute eigenvalues only of a symmetric/Hermitian matrix
345pub fn eigvalsh<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array1<A>>
346where
347    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
348    S: Data<Elem = A>,
349{
350    let _ = uplo; // Currently unused
351    crate::eigvalsh(&a.view(), None)
352}
353
354/// Compute eigenvalues only of a general matrix
355pub fn eigvals<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array1<scirs2_core::Complex<A>>>
356where
357    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
358    S: Data<Elem = A>,
359{
360    let (vals, _) = crate::eig(&a.view(), None)?;
361    Ok(vals)
362}
363
364/// Compute eigenvalues of a banded symmetric matrix
365pub fn eigvals_banded<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array1<A>>
366where
367    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
368    S: Data<Elem = A>,
369{
370    eigvalsh(a, uplo)
371}
372
373/// Compute eigenvalues and eigenvectors of a banded symmetric matrix
374pub fn eig_banded<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>
375where
376    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
377    S: Data<Elem = A>,
378{
379    eigh(a, uplo)
380}
381
382/// Compute eigenvalues and eigenvectors of a tridiagonal symmetric matrix
383pub fn eigh_tridiagonal<A>(d: &Array1<A>, e: &Array1<A>) -> LinalgResult<(Array1<A>, Array2<A>)>
384where
385    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
386{
387    // Construct tridiagonal matrix and use eigh
388    let n = d.len();
389    let mut mat = Array2::zeros((n, n));
390    for i in 0..n {
391        mat[[i, i]] = d[i];
392        if i < n - 1 {
393            mat[[i, i + 1]] = e[i];
394            mat[[i + 1, i]] = e[i];
395        }
396    }
397    eigh(&mat, UPLO::Lower)
398}
399
400/// Compute eigenvalues only of a tridiagonal symmetric matrix
401pub fn eigvalsh_tridiagonal<A>(d: &Array1<A>, e: &Array1<A>) -> LinalgResult<Array1<A>>
402where
403    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
404{
405    let (vals, _) = eigh_tridiagonal(d, e)?;
406    Ok(vals)
407}
408
409/// Matrix inverse
410pub fn inv<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
411where
412    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
413    S: Data<Elem = A>,
414{
415    crate::inv(&a.view(), None)
416}
417
418/// Determinant of a square matrix
419pub fn det<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<A>
420where
421    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
422    S: Data<Elem = A>,
423{
424    crate::det(&a.view(), None)
425}
426
427/// QR decomposition: A = QR
428pub fn qr<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
429where
430    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
431    S: Data<Elem = A>,
432{
433    crate::qr(&a.view(), None)
434}
435
436/// RQ decomposition: A = RQ
437pub fn rq<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
438where
439    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
440    S: Data<Elem = A>,
441{
442    // RQ = reverse of QR on transposed matrix
443    let t = a.t();
444    let (q, r) = crate::qr(&t.view(), None)?;
445    Ok((r.reversed_axes(), q.reversed_axes()))
446}
447
448/// LU decomposition: PA = LU
449pub fn lu<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)>
450where
451    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
452    S: Data<Elem = A>,
453{
454    crate::lu(&a.view(), None)
455}
456
457/// Cholesky decomposition: A = LLᵀ
458pub fn cholesky<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array2<A>>
459where
460    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
461    S: Data<Elem = A>,
462{
463    let _ = uplo; // Currently unused
464    crate::cholesky(&a.view(), None)
465}
466
467/// Solve linear system Ax = b
468pub fn compat_solve<A, S1, S2>(
469    a: &ArrayBase<S1, Ix2>,
470    b: &ArrayBase<S2, Ix1>,
471) -> LinalgResult<Array1<A>>
472where
473    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
474    S1: Data<Elem = A>,
475    S2: Data<Elem = A>,
476{
477    crate::solve(&a.view(), &b.view(), None)
478}
479
480/// Solve banded linear system
481pub fn solve_banded<A, S1, S2>(
482    l_and_u: (usize, usize),
483    ab: &ArrayBase<S1, Ix2>,
484    b: &ArrayBase<S2, Ix1>,
485) -> LinalgResult<Array1<A>>
486where
487    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
488    S1: Data<Elem = A>,
489    S2: Data<Elem = A>,
490{
491    let (l, u) = l_and_u;
492    crate::structured_solvers::solve_banded(l, u, &ab.view(), &b.view())
493}
494
495/// Solve triangular system
496pub fn solve_triangular<A, S1, S2>(
497    a: &ArrayBase<S1, Ix2>,
498    b: &ArrayBase<S2, Ix1>,
499    lower: bool,
500) -> LinalgResult<Array1<A>>
501where
502    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
503    S1: Data<Elem = A>,
504    S2: Data<Elem = A>,
505{
506    let _ = (a, b, lower);
507    Err(LinalgError::ComputationError(
508        "solve_triangular not yet implemented".to_string(),
509    ))
510}
511
512/// Least squares solution
513pub fn lstsq<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix1>) -> LinalgResult<Array1<A>>
514where
515    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
516    S1: Data<Elem = A>,
517    S2: Data<Elem = A>,
518{
519    let result = crate::lstsq(&a.view(), &b.view(), None)?;
520    Ok(result.x)
521}
522
523/// Pseudo-inverse (Moore-Penrose inverse)
524pub fn pinv<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
525where
526    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
527    S: Data<Elem = A>,
528{
529    // Use SVD-based pseudo-inverse: A† = V Σ† Uᵀ
530    let (u, s, vt) = crate::svd(&a.view(), true, None)?;
531    let threshold = A::from(1e-15)
532        .ok_or_else(|| LinalgError::ComputationError("Failed to convert threshold".to_string()))?
533        * s[[0]];
534    let s_inv: Array1<A> = s.map(|&val| {
535        if val > threshold {
536            A::one() / val
537        } else {
538            A::zero()
539        }
540    });
541    Ok(vt.t().dot(&Array2::from_diag(&s_inv)).dot(&u.t()))
542}
543
544/// Matrix rank
545pub fn matrix_rank<A, S>(a: &ArrayBase<S, Ix2>, tol: Option<A>) -> LinalgResult<usize>
546where
547    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
548    S: Data<Elem = A>,
549{
550    let (_, s, _) = crate::svd(&a.view(), false, None)?;
551    let threshold = tol.unwrap_or_else(|| {
552        let max_singular = s.iter().fold(A::zero(), |a, &b| if b > a { b } else { a });
553        let dim_factor = A::from(a.nrows().max(a.ncols())).unwrap_or_else(|| A::one());
554        max_singular * dim_factor * A::epsilon()
555    });
556    Ok(s.iter().filter(|&&val| val > threshold).count())
557}
558
559/// Condition number
560pub fn cond<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<A>
561where
562    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
563    S: Data<Elem = A>,
564{
565    let (_, s, _) = crate::svd(&a.view(), false, None)?;
566    let s_max = s.iter().fold(A::zero(), |a, &b| if b > a { b } else { a });
567    let s_min = s
568        .iter()
569        .fold(s_max, |a, &b| if b < a && b > A::zero() { b } else { a });
570    if s_min == A::zero() {
571        return Ok(A::infinity());
572    }
573    Ok(s_max / s_min)
574}
575
576/// Matrix norm
577pub fn norm<A, S>(a: &ArrayBase<S, Ix2>, ord: Option<&str>) -> LinalgResult<A>
578where
579    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
580    S: Data<Elem = A>,
581{
582    match ord {
583        None | Some("fro") => Ok(ArrayLinalgExt::norm_fro(a)),
584        Some("2") => {
585            let (_, s, _) = crate::svd(&a.view(), false, None)?;
586            Ok(s[[0]])
587        }
588        _ => Err(LinalgError::ComputationError(format!(
589            "norm ord={:?} not implemented",
590            ord
591        ))),
592    }
593}
594
595/// Vector norm
596pub fn vector_norm<A, S>(a: &ArrayBase<S, Ix1>, ord: Option<i32>) -> A
597where
598    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
599    S: Data<Elem = A>,
600{
601    match ord {
602        None | Some(2) => a.iter().map(|&x| x * x).sum::<A>().sqrt(),
603        Some(1) => a.iter().map(|&x| x.abs()).sum::<A>(),
604        _ => a.iter().map(|&x| x * x).sum::<A>().sqrt(), // Default to L2
605    }
606}
607
608/// Schur decomposition
609pub fn schur<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
610where
611    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
612    S: Data<Elem = A>,
613{
614    crate::schur(&a.view())
615}
616
617/// Polar decomposition: A = UP where U is unitary and P is positive semidefinite
618pub fn polar<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
619where
620    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
621    S: Data<Elem = A>,
622{
623    let (u, s, vt) = crate::svd(&a.view(), true, None)?;
624    let unitary = u.dot(&vt);
625    let hermitian = vt.t().dot(&Array2::from_diag(&s)).dot(&vt);
626    Ok((unitary, hermitian))
627}
628
629/// Matrix exponential: exp(A)
630pub fn expm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
631where
632    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
633    S: Data<Elem = A>,
634{
635    crate::expm(&a.view(), None)
636}
637
638/// Matrix logarithm: log(A)
639pub fn logm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
640where
641    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
642    S: Data<Elem = A>,
643{
644    crate::logm(&a.view())
645}
646
647/// Matrix square root: sqrt(A)
648pub fn sqrtm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
649where
650    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
651    S: Data<Elem = A>,
652{
653    let tol = A::from(1e-8)
654        .ok_or_else(|| LinalgError::ComputationError("Failed to convert tolerance".to_string()))?;
655    crate::sqrtm(&a.view(), 100, tol)
656}
657
658/// Matrix sine: sin(A)
659pub fn sinm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
660where
661    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
662    S: Data<Elem = A>,
663{
664    crate::sinm(&a.view())
665}
666
667/// Matrix cosine: cos(A)
668pub fn cosm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
669where
670    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
671    S: Data<Elem = A>,
672{
673    crate::cosm(&a.view())
674}
675
676/// Matrix tangent: tan(A)
677pub fn tanm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
678where
679    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
680    S: Data<Elem = A>,
681{
682    crate::tanm(&a.view())
683}
684
685/// General matrix function: f(A) using eigendecomposition
686pub fn funm<A, S, F>(a: &ArrayBase<S, Ix2>, func: F) -> LinalgResult<Array2<A>>
687where
688    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
689    S: Data<Elem = A>,
690    F: Fn(A) -> A,
691{
692    // For symmetric matrices, use eigendecomposition
693    let (vals, vecs) = crate::eigh(&a.view(), None)?;
694    let f_vals: Array1<A> = vals.map(|&v| func(v));
695    Ok(vecs.dot(&Array2::from_diag(&f_vals)).dot(&vecs.t()))
696}
697
698/// Fractional matrix power: A^p using eigendecomposition
699pub fn fractionalmatrix_power<A, S>(a: &ArrayBase<S, Ix2>, p: A) -> LinalgResult<Array2<A>>
700where
701    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
702    S: Data<Elem = A>,
703{
704    funm(a, |x| x.powf(p))
705}
706
707/// Block diagonal matrix construction
708pub fn block_diag<A>(blocks: &[Array2<A>]) -> Array2<A>
709where
710    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static + Zero,
711{
712    if blocks.is_empty() {
713        return Array2::zeros((0, 0));
714    }
715
716    let total_rows: usize = blocks.iter().map(|b| b.nrows()).sum();
717    let total_cols: usize = blocks.iter().map(|b| b.ncols()).sum();
718    let mut result = Array2::zeros((total_rows, total_cols));
719
720    let mut row_offset = 0;
721    let mut col_offset = 0;
722
723    for block in blocks {
724        let nrows = block.nrows();
725        let ncols = block.ncols();
726        result
727            .slice_mut(scirs2_core::ndarray::s![
728                row_offset..row_offset + nrows,
729                col_offset..col_offset + ncols
730            ])
731            .assign(block);
732        row_offset += nrows;
733        col_offset += ncols;
734    }
735
736    result
737}