1use crate::{LinalgError, LinalgResult};
21use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand};
22use scirs2_core::numeric::{Float, NumAssign, Zero};
23use scirs2_core::Complex;
24use std::iter::Sum;
25
/// Selects which triangle of a symmetric/Hermitian matrix an operation should read,
/// analogous to the `UPLO` flag of LAPACK-style interfaces.
///
/// Note: several wrappers in this module accept this flag but currently ignore it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UPLO {
    /// The upper triangle (including the diagonal).
    Upper,
    /// The lower triangle (including the diagonal).
    Lower,
}
34
35pub trait ArrayLinalgExt<A, S: scirs2_core::ndarray::RawData> {
37 fn svd(&self, compute_uv: bool) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)>;
39
40 #[allow(clippy::type_complexity)]
42 fn eig(
43 &self,
44 ) -> LinalgResult<(
45 Array1<scirs2_core::Complex<A>>,
46 Array2<scirs2_core::Complex<A>>,
47 )>;
48
49 fn eigh(&self, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>;
51
52 fn eigvalsh(&self, uplo: UPLO) -> LinalgResult<Array1<A>>;
54
55 fn inv(&self) -> LinalgResult<Array2<A>>;
57
58 fn solve(&self, b: &ArrayBase<S, Ix1>) -> LinalgResult<Array1<A>>;
60
61 fn solve_into(&self, b: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>;
63
64 fn norm_l2(&self) -> A;
66
67 fn norm_fro(&self) -> A;
69
70 fn det(&self) -> LinalgResult<A>;
72
73 fn qr(&self) -> LinalgResult<(Array2<A>, Array2<A>)>;
75
76 fn lu(&self) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)>;
78
79 fn cholesky(&self) -> LinalgResult<Array2<A>>;
81}
82
83impl<A, S> ArrayLinalgExt<A, S> for ArrayBase<S, Ix2>
84where
85 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
86 S: Data<Elem = A>,
87{
88 fn svd(&self, compute_uv: bool) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)> {
89 crate::svd(&self.view(), compute_uv, None)
90 }
91
92 fn eig(
93 &self,
94 ) -> LinalgResult<(
95 Array1<scirs2_core::Complex<A>>,
96 Array2<scirs2_core::Complex<A>>,
97 )> {
98 crate::eig(&self.view(), None)
99 }
100
101 fn eigh(&self, _uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)> {
102 crate::eigh(&self.view(), None)
104 }
105
106 fn eigvalsh(&self, _uplo: UPLO) -> LinalgResult<Array1<A>> {
107 crate::eigvalsh(&self.view(), None)
108 }
109
110 fn inv(&self) -> LinalgResult<Array2<A>> {
111 crate::inv(&self.view(), None)
112 }
113
114 fn solve(&self, b: &ArrayBase<S, Ix1>) -> LinalgResult<Array1<A>> {
115 crate::solve(&self.view(), &b.view(), None)
116 }
117
118 fn solve_into(&self, b: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>> {
119 crate::solve_multiple(&self.view(), &b.view(), None)
120 }
121
122 fn norm_l2(&self) -> A {
123 self.iter().map(|&x| x * x).sum::<A>().sqrt()
125 }
126
127 fn norm_fro(&self) -> A {
128 self.iter().map(|&x| x * x).sum::<A>().sqrt()
129 }
130
131 fn det(&self) -> LinalgResult<A> {
132 crate::det(&self.view(), None)
133 }
134
135 fn qr(&self) -> LinalgResult<(Array2<A>, Array2<A>)> {
136 crate::qr(&self.view(), None)
137 }
138
139 fn lu(&self) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)> {
140 crate::lu(&self.view(), None)
141 }
142
143 fn cholesky(&self) -> LinalgResult<Array2<A>> {
144 crate::cholesky(&self.view(), None)
145 }
146}
147
148pub trait Solve<A> {
150 type Output;
152
153 fn solve(&self, rhs: &Self) -> LinalgResult<Self::Output>;
155}
156
157pub trait SVD {
159 type S;
161 type U;
163 type Vt;
165
166 fn svd(&self, compute_uv: bool) -> LinalgResult<(Self::U, Self::S, Self::Vt)>;
168}
169
170impl<A, S> SVD for ArrayBase<S, Ix2>
171where
172 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
173 S: Data<Elem = A>,
174{
175 type S = Array1<A>;
176 type U = Array2<A>;
177 type Vt = Array2<A>;
178
179 fn svd(&self, compute_uv: bool) -> LinalgResult<(Self::U, Self::S, Self::Vt)> {
180 ArrayLinalgExt::svd(self, compute_uv)
181 }
182}
183
184pub trait Eig {
186 type EigVal;
188 type EigVec;
190
191 fn eig(&self) -> LinalgResult<(Self::EigVal, Self::EigVec)>;
193}
194
195impl<A, S> Eig for ArrayBase<S, Ix2>
196where
197 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
198 S: Data<Elem = A>,
199{
200 type EigVal = Array1<scirs2_core::Complex<A>>;
201 type EigVec = Array2<scirs2_core::Complex<A>>;
202
203 fn eig(&self) -> LinalgResult<(Self::EigVal, Self::EigVec)> {
204 ArrayLinalgExt::eig(self)
205 }
206}
207
208pub trait Eigh {
210 type EigVal;
212 type EigVec;
214
215 fn eigh(&self, uplo: UPLO) -> LinalgResult<(Self::EigVal, Self::EigVec)>;
217}
218
219impl<A, S> Eigh for ArrayBase<S, Ix2>
220where
221 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
222 S: Data<Elem = A>,
223{
224 type EigVal = Array1<A>;
225 type EigVec = Array2<A>;
226
227 fn eigh(&self, uplo: UPLO) -> LinalgResult<(Self::EigVal, Self::EigVec)> {
228 ArrayLinalgExt::eigh(self, uplo)
229 }
230}
231
232pub trait Inverse {
234 type Output;
236
237 fn inv(&self) -> LinalgResult<Self::Output>;
239}
240
241impl<A, S> Inverse for ArrayBase<S, Ix2>
242where
243 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
244 S: Data<Elem = A>,
245{
246 type Output = Array2<A>;
247
248 fn inv(&self) -> LinalgResult<Self::Output> {
249 ArrayLinalgExt::inv(self)
250 }
251}
252
/// Norms of vectors and matrices.
pub trait Norm<A> {
    /// The default norm for the implementing type.
    fn norm(&self) -> A;

    /// Euclidean (element-wise 2-) norm.
    fn norm_l2(&self) -> A;

    /// Frobenius norm.
    fn norm_fro(&self) -> A;
}
264
265impl<A, S> Norm<A> for ArrayBase<S, Ix2>
266where
267 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
268 S: Data<Elem = A>,
269{
270 fn norm(&self) -> A {
271 ArrayLinalgExt::norm_fro(self)
272 }
273
274 fn norm_l2(&self) -> A {
275 ArrayLinalgExt::norm_l2(self)
276 }
277
278 fn norm_fro(&self) -> A {
279 ArrayLinalgExt::norm_fro(self)
280 }
281}
282
283impl<A, S> Norm<A> for ArrayBase<S, Ix1>
284where
285 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
286 S: Data<Elem = A>,
287{
288 fn norm(&self) -> A {
289 self.norm_l2()
290 }
291
292 fn norm_l2(&self) -> A {
293 self.iter().map(|&x| x * x).sum::<A>().sqrt()
294 }
295
296 fn norm_fro(&self) -> A {
297 self.norm_l2()
298 }
299}
300
301pub type SvdResult<A> = (Array2<A>, Array1<A>, Array2<A>);
309
310pub fn svd<A, S>(a: &ArrayBase<S, Ix2>, compute_uv: bool) -> LinalgResult<SvdResult<A>>
312where
313 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
314 S: Data<Elem = A>,
315{
316 crate::svd(&a.view(), compute_uv, None)
317}
318
319#[allow(clippy::type_complexity)]
321pub fn eig<A, S>(
322 a: &ArrayBase<S, Ix2>,
323) -> LinalgResult<(
324 Array1<scirs2_core::Complex<A>>,
325 Array2<scirs2_core::Complex<A>>,
326)>
327where
328 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
329 S: Data<Elem = A>,
330{
331 crate::eig(&a.view(), None)
332}
333
334pub fn eigh<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>
336where
337 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
338 S: Data<Elem = A>,
339{
340 let _ = uplo; crate::eigh(&a.view(), None)
342}
343
344pub fn eigvalsh<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array1<A>>
346where
347 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
348 S: Data<Elem = A>,
349{
350 let _ = uplo; crate::eigvalsh(&a.view(), None)
352}
353
354pub fn eigvals<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array1<scirs2_core::Complex<A>>>
356where
357 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
358 S: Data<Elem = A>,
359{
360 let (vals, _) = crate::eig(&a.view(), None)?;
361 Ok(vals)
362}
363
364pub fn eigvals_banded<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array1<A>>
366where
367 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
368 S: Data<Elem = A>,
369{
370 eigvalsh(a, uplo)
371}
372
373pub fn eig_banded<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>
375where
376 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
377 S: Data<Elem = A>,
378{
379 eigh(a, uplo)
380}
381
382pub fn eigh_tridiagonal<A>(d: &Array1<A>, e: &Array1<A>) -> LinalgResult<(Array1<A>, Array2<A>)>
384where
385 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
386{
387 let n = d.len();
389 let mut mat = Array2::zeros((n, n));
390 for i in 0..n {
391 mat[[i, i]] = d[i];
392 if i < n - 1 {
393 mat[[i, i + 1]] = e[i];
394 mat[[i + 1, i]] = e[i];
395 }
396 }
397 eigh(&mat, UPLO::Lower)
398}
399
400pub fn eigvalsh_tridiagonal<A>(d: &Array1<A>, e: &Array1<A>) -> LinalgResult<Array1<A>>
402where
403 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
404{
405 let (vals, _) = eigh_tridiagonal(d, e)?;
406 Ok(vals)
407}
408
409pub fn inv<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
411where
412 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
413 S: Data<Elem = A>,
414{
415 crate::inv(&a.view(), None)
416}
417
418pub fn det<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<A>
420where
421 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
422 S: Data<Elem = A>,
423{
424 crate::det(&a.view(), None)
425}
426
427pub fn qr<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
429where
430 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
431 S: Data<Elem = A>,
432{
433 crate::qr(&a.view(), None)
434}
435
436pub fn rq<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
438where
439 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
440 S: Data<Elem = A>,
441{
442 let t = a.t();
444 let (q, r) = crate::qr(&t.view(), None)?;
445 Ok((r.reversed_axes(), q.reversed_axes()))
446}
447
448pub fn lu<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)>
450where
451 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
452 S: Data<Elem = A>,
453{
454 crate::lu(&a.view(), None)
455}
456
457pub fn cholesky<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array2<A>>
459where
460 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
461 S: Data<Elem = A>,
462{
463 let _ = uplo; crate::cholesky(&a.view(), None)
465}
466
467pub fn compat_solve<A, S1, S2>(
469 a: &ArrayBase<S1, Ix2>,
470 b: &ArrayBase<S2, Ix1>,
471) -> LinalgResult<Array1<A>>
472where
473 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
474 S1: Data<Elem = A>,
475 S2: Data<Elem = A>,
476{
477 crate::solve(&a.view(), &b.view(), None)
478}
479
480pub fn solve_banded<A, S1, S2>(
482 l_and_u: (usize, usize),
483 ab: &ArrayBase<S1, Ix2>,
484 b: &ArrayBase<S2, Ix1>,
485) -> LinalgResult<Array1<A>>
486where
487 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
488 S1: Data<Elem = A>,
489 S2: Data<Elem = A>,
490{
491 let (l, u) = l_and_u;
492 crate::structured_solvers::solve_banded(l, u, &ab.view(), &b.view())
493}
494
495pub fn solve_triangular<A, S1, S2>(
497 a: &ArrayBase<S1, Ix2>,
498 b: &ArrayBase<S2, Ix1>,
499 lower: bool,
500) -> LinalgResult<Array1<A>>
501where
502 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
503 S1: Data<Elem = A>,
504 S2: Data<Elem = A>,
505{
506 let _ = (a, b, lower);
507 Err(LinalgError::ComputationError(
508 "solve_triangular not yet implemented".to_string(),
509 ))
510}
511
512pub fn lstsq<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix1>) -> LinalgResult<Array1<A>>
514where
515 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
516 S1: Data<Elem = A>,
517 S2: Data<Elem = A>,
518{
519 let result = crate::lstsq(&a.view(), &b.view(), None)?;
520 Ok(result.x)
521}
522
523pub fn pinv<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
525where
526 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
527 S: Data<Elem = A>,
528{
529 let (u, s, vt) = crate::svd(&a.view(), true, None)?;
531 let threshold = A::from(1e-15)
532 .ok_or_else(|| LinalgError::ComputationError("Failed to convert threshold".to_string()))?
533 * s[[0]];
534 let s_inv: Array1<A> = s.map(|&val| {
535 if val > threshold {
536 A::one() / val
537 } else {
538 A::zero()
539 }
540 });
541 Ok(vt.t().dot(&Array2::from_diag(&s_inv)).dot(&u.t()))
542}
543
544pub fn matrix_rank<A, S>(a: &ArrayBase<S, Ix2>, tol: Option<A>) -> LinalgResult<usize>
546where
547 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
548 S: Data<Elem = A>,
549{
550 let (_, s, _) = crate::svd(&a.view(), false, None)?;
551 let threshold = tol.unwrap_or_else(|| {
552 let max_singular = s.iter().fold(A::zero(), |a, &b| if b > a { b } else { a });
553 let dim_factor = A::from(a.nrows().max(a.ncols())).unwrap_or_else(|| A::one());
554 max_singular * dim_factor * A::epsilon()
555 });
556 Ok(s.iter().filter(|&&val| val > threshold).count())
557}
558
559pub fn cond<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<A>
561where
562 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
563 S: Data<Elem = A>,
564{
565 let (_, s, _) = crate::svd(&a.view(), false, None)?;
566 let s_max = s.iter().fold(A::zero(), |a, &b| if b > a { b } else { a });
567 let s_min = s
568 .iter()
569 .fold(s_max, |a, &b| if b < a && b > A::zero() { b } else { a });
570 if s_min == A::zero() {
571 return Ok(A::infinity());
572 }
573 Ok(s_max / s_min)
574}
575
576pub fn norm<A, S>(a: &ArrayBase<S, Ix2>, ord: Option<&str>) -> LinalgResult<A>
578where
579 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
580 S: Data<Elem = A>,
581{
582 match ord {
583 None | Some("fro") => Ok(ArrayLinalgExt::norm_fro(a)),
584 Some("2") => {
585 let (_, s, _) = crate::svd(&a.view(), false, None)?;
586 Ok(s[[0]])
587 }
588 _ => Err(LinalgError::ComputationError(format!(
589 "norm ord={:?} not implemented",
590 ord
591 ))),
592 }
593}
594
595pub fn vector_norm<A, S>(a: &ArrayBase<S, Ix1>, ord: Option<i32>) -> A
597where
598 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
599 S: Data<Elem = A>,
600{
601 match ord {
602 None | Some(2) => a.iter().map(|&x| x * x).sum::<A>().sqrt(),
603 Some(1) => a.iter().map(|&x| x.abs()).sum::<A>(),
604 _ => a.iter().map(|&x| x * x).sum::<A>().sqrt(), }
606}
607
608pub fn schur<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
610where
611 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
612 S: Data<Elem = A>,
613{
614 crate::schur(&a.view())
615}
616
617pub fn polar<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
619where
620 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
621 S: Data<Elem = A>,
622{
623 let (u, s, vt) = crate::svd(&a.view(), true, None)?;
624 let unitary = u.dot(&vt);
625 let hermitian = vt.t().dot(&Array2::from_diag(&s)).dot(&vt);
626 Ok((unitary, hermitian))
627}
628
629pub fn expm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
631where
632 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
633 S: Data<Elem = A>,
634{
635 crate::expm(&a.view(), None)
636}
637
638pub fn logm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
640where
641 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
642 S: Data<Elem = A>,
643{
644 crate::logm(&a.view())
645}
646
647pub fn sqrtm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
649where
650 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
651 S: Data<Elem = A>,
652{
653 let tol = A::from(1e-8)
654 .ok_or_else(|| LinalgError::ComputationError("Failed to convert tolerance".to_string()))?;
655 crate::sqrtm(&a.view(), 100, tol)
656}
657
658pub fn sinm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
660where
661 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
662 S: Data<Elem = A>,
663{
664 crate::sinm(&a.view())
665}
666
667pub fn cosm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
669where
670 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
671 S: Data<Elem = A>,
672{
673 crate::cosm(&a.view())
674}
675
676pub fn tanm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
678where
679 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
680 S: Data<Elem = A>,
681{
682 crate::tanm(&a.view())
683}
684
685pub fn funm<A, S, F>(a: &ArrayBase<S, Ix2>, func: F) -> LinalgResult<Array2<A>>
687where
688 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
689 S: Data<Elem = A>,
690 F: Fn(A) -> A,
691{
692 let (vals, vecs) = crate::eigh(&a.view(), None)?;
694 let f_vals: Array1<A> = vals.map(|&v| func(v));
695 Ok(vecs.dot(&Array2::from_diag(&f_vals)).dot(&vecs.t()))
696}
697
698pub fn fractionalmatrix_power<A, S>(a: &ArrayBase<S, Ix2>, p: A) -> LinalgResult<Array2<A>>
700where
701 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
702 S: Data<Elem = A>,
703{
704 funm(a, |x| x.powf(p))
705}
706
707pub fn block_diag<A>(blocks: &[Array2<A>]) -> Array2<A>
709where
710 A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static + Zero,
711{
712 if blocks.is_empty() {
713 return Array2::zeros((0, 0));
714 }
715
716 let total_rows: usize = blocks.iter().map(|b| b.nrows()).sum();
717 let total_cols: usize = blocks.iter().map(|b| b.ncols()).sum();
718 let mut result = Array2::zeros((total_rows, total_cols));
719
720 let mut row_offset = 0;
721 let mut col_offset = 0;
722
723 for block in blocks {
724 let nrows = block.nrows();
725 let ncols = block.ncols();
726 result
727 .slice_mut(scirs2_core::ndarray::s![
728 row_offset..row_offset + nrows,
729 col_offset..col_offset + ncols
730 ])
731 .assign(block);
732 row_offset += nrows;
733 col_offset += ncols;
734 }
735
736 result
737}