1pub use ::ndarray::*;
33
34#[cfg(feature = "random")]
41pub use ndarray_rand::{rand_distr as distributions, RandomExt, SamplingStrategy};
42
43#[cfg(feature = "array_stats")]
48pub use ndarray_stats::{
49 errors as stats_errors, interpolate, CorrelationExt, DeviationExt, MaybeNan, QuantileExt,
50 Sort1dExt, SummaryStatisticsExt,
51};
52
53pub mod utils {
61 use super::*;
62
63 pub fn eye<A>(n: usize) -> Array2<A>
65 where
66 A: Clone + num_traits::Zero + num_traits::One,
67 {
68 let mut arr = Array2::zeros((n, n));
69 for i in 0..n {
70 arr[[i, i]] = A::one();
71 }
72 arr
73 }
74
75 pub fn diag<A>(v: &Array1<A>) -> Array2<A>
77 where
78 A: Clone + num_traits::Zero,
79 {
80 let n = v.len();
81 let mut arr = Array2::zeros((n, n));
82 for i in 0..n {
83 arr[[i, i]] = v[i].clone();
84 }
85 arr
86 }
87
88 pub fn allclose<A, D>(
90 a: &ArrayBase<impl Data<Elem = A>, D>,
91 b: &ArrayBase<impl Data<Elem = A>, D>,
92 rtol: A,
93 atol: A,
94 ) -> bool
95 where
96 A: PartialOrd
97 + std::ops::Sub<Output = A>
98 + std::ops::Mul<Output = A>
99 + std::ops::Add<Output = A>
100 + Clone,
101 D: Dimension,
102 {
103 if a.shape() != b.shape() {
104 return false;
105 }
106
107 a.iter().zip(b.iter()).all(|(a_val, b_val)| {
108 let diff = if a_val > b_val {
109 a_val.clone() - b_val.clone()
110 } else {
111 b_val.clone() - a_val.clone()
112 };
113
114 let threshold = atol.clone()
115 + rtol.clone()
116 * (if a_val > b_val {
117 a_val.clone()
118 } else {
119 b_val.clone()
120 });
121
122 diff <= threshold
123 })
124 }
125
126 pub fn concatenate<A, D>(
128 axis: Axis,
129 arrays: &[ArrayView<A, D>],
130 ) -> Result<Array<A, D>, ShapeError>
131 where
132 A: Clone,
133 D: Dimension + RemoveAxis,
134 {
135 ndarray::concatenate(axis, arrays)
136 }
137
138 pub fn stack<A, D>(
140 axis: Axis,
141 arrays: &[ArrayView<A, D>],
142 ) -> Result<Array<A, D::Larger>, ShapeError>
143 where
144 A: Clone,
145 D: Dimension,
146 D::Larger: RemoveAxis,
147 {
148 ndarray::stack(axis, arrays)
149 }
150}
151
152pub mod compat {
159 pub use super::*;
160 use crate::numeric::{Float, FromPrimitive};
161
162 pub type DynArray<T> = ArrayD<T>;
164 pub type Matrix<T> = Array2<T>;
165 pub type Vector<T> = Array1<T>;
166 pub type Tensor3<T> = Array3<T>;
167 pub type Tensor4<T> = Array4<T>;
168
169 pub trait ArrayStatCompat<T> {
192 fn mean_or(&self, default: T) -> T;
198
199 fn var_or(&self, ddof: T, default: T) -> T;
201
202 fn std_or(&self, ddof: T, default: T) -> T;
204 }
205
206 impl<T, S, D> ArrayStatCompat<T> for ArrayBase<S, D>
207 where
208 T: Float + FromPrimitive,
209 S: Data<Elem = T>,
210 D: Dimension,
211 {
212 fn mean_or(&self, default: T) -> T {
213 self.mean().unwrap_or(default)
215 }
216
217 fn var_or(&self, ddof: T, default: T) -> T {
218 let v = self.var(ddof);
220 if v.is_nan() {
221 default
222 } else {
223 v
224 }
225 }
226
227 fn std_or(&self, ddof: T, default: T) -> T {
228 let s = self.std(ddof);
230 if s.is_nan() {
231 default
232 } else {
233 s
234 }
235 }
236 }
237
238 pub use crate::ndarray_ext::{
240 broadcast_1d_to_2d,
241 broadcast_apply,
242 fancy_index_2d,
243 indexing,
245 is_broadcast_compatible,
246 manipulation,
247 mask_select,
248 matrix,
249 reshape_2d,
250 split_2d,
251 stack_2d,
252 stats,
253 take_2d,
254 transpose_2d,
255 where_condition,
256 };
257}
258
259pub mod prelude {
265 pub use super::{
266 arr1,
267 arr2,
268 array,
270 azip,
271 concatenate,
273 s,
274 stack,
275
276 stack as stack_fn,
277 Array,
279 Array0,
280 Array1,
281 Array2,
282 Array3,
283 ArrayD,
284 ArrayView,
285 ArrayView1,
286 ArrayView2,
287 ArrayViewMut,
288
289 Axis,
291 Dimension,
293 Ix1,
294 Ix2,
295 Ix3,
296 IxDyn,
297 ScalarOperand,
298 ShapeBuilder,
299
300 Zip,
301 };
302
303 #[cfg(feature = "random")]
304 pub use super::RandomExt;
305
306 pub type Matrix<T> = super::Array2<T>;
308 pub type Vector<T> = super::Array1<T>;
309}
310
#[cfg(test)]
pub mod examples {
    use super::*;

    /// End-to-end check of construction, slicing, arithmetic, reductions and broadcasting
    /// through the re-exported ndarray API, asserting full expected values.
    #[test]
    fn test_complete_functionality() {
        let a = array![[1., 2.], [3., 4.]];
        assert_eq!(a.shape(), &[2, 2]);

        // First column of `a`.
        let slice = a.slice(s![.., 0]);
        assert_eq!(slice.len(), 2);
        assert_eq!(slice, array![1., 3.]);

        // Element-wise addition of two borrowed arrays.
        let b = &a + &a;
        assert_eq!(b, array![[2., 4.], [6., 8.]]);

        // Column sums.
        let sum = a.sum_axis(Axis(0));
        assert_eq!(sum.len(), 2);
        assert_eq!(sum, array![4., 6.]);

        // A 1-D array broadcasts across every row of `a`.
        let c = array![1., 2.];
        let d = &a + &c;
        assert_eq!(d, array![[2., 4.], [4., 6.]]);
    }
}
370
/// Namespace reserved for migration notes; no items are declared in it here.
pub mod migration_guide {}
416
417pub use compat::ArrayStatCompat;
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_macro_available() {
        let grid = array![[1, 2], [3, 4]];
        assert_eq!(grid[[0, 0]], 1);
        assert_eq!(grid.shape(), &[2, 2]);
    }

    #[test]
    fn test_s_macro_available() {
        let grid = array![[1, 2, 3], [4, 5, 6]];
        let tail_columns = grid.slice(s![.., 1..]);
        assert_eq!(tail_columns.shape(), &[2, 2]);
    }

    #[test]
    fn test_axis_operations() {
        let grid = array![[1., 2.], [3., 4.]];
        let column_sums = grid.sum_axis(Axis(0));
        assert_eq!(column_sums, array![4., 6.]);
    }

    #[test]
    fn test_views_and_iteration() {
        let mut grid = array![[1, 2], [3, 4]];

        // Shared view: read-only traversal.
        let shared: ArrayView<_, _> = grid.view();
        assert!(shared.iter().all(|&val| val > 0));

        // Exclusive view: double every element in place through the view.
        let mut exclusive: ArrayViewMut<_, _> = grid.view_mut();
        exclusive.iter_mut().for_each(|val| *val *= 2);

        assert_eq!(grid[[0, 0]], 2);
    }

    #[test]
    fn test_concatenate_and_stack() {
        let top = array![[1, 2], [3, 4]];
        let bottom = array![[5, 6], [7, 8]];
        let views = [top.view(), bottom.view()];

        let joined = concatenate(Axis(0), &views).expect("Operation failed");
        assert_eq!(joined.shape(), &[4, 2]);

        let layered = crate::ndarray::stack(Axis(0), &views).expect("Operation failed");
        assert_eq!(layered.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_zip_operations() {
        let lhs = array![1, 2, 3];
        let rhs = array![4, 5, 6];
        let mut out = array![0, 0, 0];

        azip!((x in &lhs, y in &rhs, o in &mut out) *o = x + y);

        assert_eq!(out, array![5, 7, 9]);
    }

    #[test]
    fn test_array_stat_compat() {
        use compat::ArrayStatCompat;

        let samples = array![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(samples.mean_or(0.0), 3.0);
        assert!(samples.var_or(1.0, 0.0) > 0.0);
        assert!(samples.std_or(1.0, 0.0) > 0.0);

        let empty = Array1::<f64>::from(vec![]);
        assert_eq!(empty.mean_or(0.0), 0.0);
    }
}
517
/// Linear-algebra traits named after the `ndarray-linalg` API (`Solve`, `SVD`, `Eigh`, `Norm`,
/// `QR`, `Eig`, `Cholesky`).
///
/// They are implemented for `f64` and `Complex<f64>` arrays on top of the crate's own
/// routines.
#[cfg(feature = "linalg")]
pub mod ndarray_linalg {
    use crate::linalg::prelude::*;
    use crate::ndarray::*;
    use num_complex::Complex;

    use oxiblas_ndarray::lapack::{
        cholesky_hermitian_ndarray, eig_hermitian_ndarray, qr_complex_ndarray, svd_complex_ndarray,
    };

    pub use crate::linalg::{LapackError, LapackResult};

    /// Selects a triangle of a matrix.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum UPLO {
        /// Upper triangle.
        Upper,
        /// Lower triangle.
        Lower,
    }

    /// Linear system solve.
    pub trait Solve<A> {
        /// Solves `self · x = b` for `x` by delegating to the crate's `solve_ndarray`.
        ///
        /// # Errors
        /// Propagates any [`LapackError`] reported by the underlying solver.
        fn solve_into(&self, b: &Array1<A>) -> Result<Array1<A>, LapackError>;
    }

    impl Solve<f64> for Array2<f64> {
        #[inline]
        fn solve_into(&self, b: &Array1<f64>) -> Result<Array1<f64>, LapackError> {
            solve_ndarray(self, b)
        }
    }

    impl Solve<Complex<f64>> for Array2<Complex<f64>> {
        #[inline]
        fn solve_into(
            &self,
            b: &Array1<Complex<f64>>,
        ) -> Result<Array1<Complex<f64>>, LapackError> {
            solve_ndarray(self, b)
        }
    }

    /// Singular value decomposition.
    pub trait SVD {
        /// Element type of the singular vectors.
        type Elem;
        /// Real type of the singular values.
        type Real;

        /// Returns `(u, s, vt)` from the underlying SVD routine.
        ///
        /// `compute_u` and `compute_vt` are accepted for API compatibility but ignored by the
        /// implementations in this module. Both factors are always computed and returned.
        ///
        /// # Errors
        /// Propagates any [`LapackError`] reported by the underlying routine.
        fn svd(
            &self,
            compute_u: bool,
            compute_vt: bool,
        ) -> Result<(Array2<Self::Elem>, Array1<Self::Real>, Array2<Self::Elem>), LapackError>;
    }

    impl SVD for Array2<f64> {
        type Elem = f64;
        type Real = f64;

        #[inline]
        fn svd(
            &self,
            _compute_u: bool,
            _compute_vt: bool,
        ) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>), LapackError> {
            let result = svd_ndarray(self)?;
            Ok((result.u, result.s, result.vt))
        }
    }

    impl SVD for Array2<Complex<f64>> {
        type Elem = Complex<f64>;
        type Real = f64;

        #[inline]
        fn svd(
            &self,
            _compute_u: bool,
            _compute_vt: bool,
        ) -> Result<(Array2<Complex<f64>>, Array1<f64>, Array2<Complex<f64>>), LapackError>
        {
            let result = svd_complex_ndarray(self)?;
            Ok((result.u, result.s, result.vt))
        }
    }

    /// Eigendecomposition of symmetric / Hermitian matrices.
    pub trait Eigh {
        /// Element type of the eigenvectors.
        type Elem;
        /// Real type of the eigenvalues.
        type Real;

        /// Returns `(eigenvalues, eigenvectors)` from the underlying routine.
        ///
        /// NOTE: `uplo` is ignored by the implementations in this module. The whole matrix is
        /// passed to the underlying routine, so callers should supply a fully populated
        /// symmetric/Hermitian matrix.
        ///
        /// # Errors
        /// Propagates any [`LapackError`] reported by the underlying routine.
        fn eigh(&self, uplo: UPLO)
            -> Result<(Array1<Self::Real>, Array2<Self::Elem>), LapackError>;
    }

    impl Eigh for Array2<f64> {
        type Elem = f64;
        type Real = f64;

        #[inline]
        fn eigh(&self, _uplo: UPLO) -> Result<(Array1<f64>, Array2<f64>), LapackError> {
            let result = eig_symmetric(self)?;
            Ok((result.eigenvalues, result.eigenvectors))
        }
    }

    impl Eigh for Array2<Complex<f64>> {
        type Elem = Complex<f64>;
        type Real = f64;

        #[inline]
        fn eigh(&self, _uplo: UPLO) -> Result<(Array1<f64>, Array2<Complex<f64>>), LapackError> {
            eig_hermitian_ndarray(self)
        }
    }

    /// Euclidean (vectors) / Frobenius (matrices) norm.
    pub trait Norm {
        /// Real type of the norm.
        type Real;

        /// Square root of the sum of squared magnitudes of all elements.
        ///
        /// The implementations in this module never return `Err`.
        fn norm_l2(&self) -> Result<Self::Real, LapackError>;
    }

    impl Norm for Array2<f64> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x * x).sum();
            Ok(sum_sq.sqrt())
        }
    }

    impl Norm for Array2<Complex<f64>> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x.norm_sqr()).sum();
            Ok(sum_sq.sqrt())
        }
    }

    impl Norm for Array1<f64> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x * x).sum();
            Ok(sum_sq.sqrt())
        }
    }

    impl Norm for Array1<Complex<f64>> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x.norm_sqr()).sum();
            Ok(sum_sq.sqrt())
        }
    }

    /// QR decomposition.
    pub trait QR {
        /// Element type of the factors.
        type Elem;

        /// Returns `(q, r)` from the underlying QR routine.
        ///
        /// # Errors
        /// Propagates any [`LapackError`] reported by the underlying routine.
        fn qr(&self) -> Result<(Array2<Self::Elem>, Array2<Self::Elem>), LapackError>;
    }

    impl QR for Array2<f64> {
        type Elem = f64;

        #[inline]
        fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), LapackError> {
            let result = qr_ndarray(self)?;
            Ok((result.q, result.r))
        }
    }

    impl QR for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        #[inline]
        fn qr(&self) -> Result<(Array2<Complex<f64>>, Array2<Complex<f64>>), LapackError> {
            let result = qr_complex_ndarray(self)?;
            Ok((result.q, result.r))
        }
    }

    /// General (non-Hermitian) eigendecomposition.
    pub trait Eig {
        /// Element type of the eigenvalues and eigenvectors.
        type Elem;

        /// Returns `(eigenvalues, eigenvectors)`.
        ///
        /// Column `i` of the eigenvector matrix pairs with `eigenvalues[i]`.
        fn eig(&self) -> Result<(Array1<Self::Elem>, Array2<Self::Elem>), LapackError>;
    }

    impl Eig for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        /// Computes eigenvalues and right eigenvectors with a pure-Rust complex Schur
        /// decomposition in three steps:
        /// 1. Householder reduction to upper Hessenberg form.
        /// 2. Implicit single-shift QR iteration to upper-triangular form.
        /// 3. Back-substitution on the triangular factor.
        ///
        /// Eigenvectors are normalised to unit 2-norm. If an eigenvalue fails to converge within
        /// the iteration budget, the current diagonal entry is accepted and no error is raised.
        ///
        /// # Errors
        /// Returns [`LapackError::DimensionMismatch`] if the matrix is not square.
        fn eig(&self) -> Result<(Array1<Complex<f64>>, Array2<Complex<f64>>), LapackError> {
            let (m, n) = self.dim();
            if m != n {
                return Err(LapackError::DimensionMismatch(
                    "Matrix must be square for eigendecomposition".to_string(),
                ));
            }
            if n == 0 {
                return Ok((
                    Array1::<Complex<f64>>::zeros(0),
                    Array2::<Complex<f64>>::zeros((0, 0)),
                ));
            }
            if n == 1 {
                let eigenvalue = self[[0, 0]];
                let eigenvector = Array2::from_elem((1, 1), Complex::new(1.0, 0.0));
                return Ok((Array1::from_vec(vec![eigenvalue]), eigenvector));
            }

            // Invariant maintained by every step below: self = q · h · qᴴ (q unitary).
            let mut h = self.clone();
            let mut q = Array2::<Complex<f64>>::eye(n);
            reduce_to_hessenberg(&mut h, &mut q);
            hessenberg_to_schur(&mut h, &mut q);

            let eigenvalues: Array1<Complex<f64>> = h.diag().to_owned();
            let eigenvectors = q.dot(&triangular_eigenvectors(&h));
            Ok((eigenvalues, eigenvectors))
        }
    }

    /// Reduces `h` to upper Hessenberg form in place using Householder reflections
    /// `P = I - 2·v·vᴴ`.
    ///
    /// Each reflection is applied as the similarity `h <- P·h·P` and accumulated as `q <- q·P`.
    fn reduce_to_hessenberg(h: &mut Array2<Complex<f64>>, q: &mut Array2<Complex<f64>>) {
        let n = h.nrows();
        let two = Complex::new(2.0, 0.0);

        // The last two columns already satisfy the Hessenberg pattern.
        for col in 0..n.saturating_sub(2) {
            let offset = col + 1;
            let mut x: Vec<Complex<f64>> = (offset..n).map(|r| h[[r, col]]).collect();

            let norm_x = x.iter().map(|v| v.norm_sqr()).sum::<f64>().sqrt();
            if norm_x < 1e-300 {
                continue;
            }

            // Reflect onto `-phase·‖x‖·e1`, using the phase of x[0] to avoid cancellation.
            let phase = if x[0].norm() > 1e-300 {
                x[0] / x[0].norm()
            } else {
                Complex::new(1.0, 0.0)
            };
            x[0] += phase * norm_x;

            let norm_v = x.iter().map(|v| v.norm_sqr()).sum::<f64>().sqrt();
            if norm_v < 1e-300 {
                continue;
            }
            let v: Vec<Complex<f64>> = x.iter().map(|vi| *vi / norm_v).collect();

            // h <- P·h. Columns before `col` are already zero in rows `offset..`.
            for c in col..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| vi.conj() * h[[offset + i, c]])
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    h[[offset + i, c]] -= two * vi * dot;
                }
            }

            // h <- h·P
            for r in 0..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| h[[r, offset + i]] * vi)
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    h[[r, offset + i]] -= two * dot * vi.conj();
                }
            }

            // q <- q·P
            for r in 0..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| q[[r, offset + i]] * vi)
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    q[[r, offset + i]] -= two * dot * vi.conj();
                }
            }

            // Entries below the subdiagonal vanish in exact arithmetic; clear rounding noise.
            for r in col + 2..n {
                h[[r, col]] = Complex::new(0.0, 0.0);
            }
        }
    }

    /// Returns `true` when the subdiagonal entry `h[l, l-1]` is negligible, so the Hessenberg
    /// matrix may be split there.
    ///
    /// "Negligible" means small relative to the two neighbouring diagonal entries, or below the
    /// underflow threshold.
    fn subdiagonal_negligible(h: &Array2<Complex<f64>>, l: usize) -> bool {
        let sub = h[[l, l - 1]].norm();
        let diag = h[[l - 1, l - 1]].norm() + h[[l, l]].norm();
        sub <= 1e-14 * diag || sub <= f64::MIN_POSITIVE.sqrt()
    }

    /// Wilkinson shift: the eigenvalue of the trailing 2×2 block `h[p-2..p, p-2..p]` closer to
    /// `h[p-1, p-1]`.
    fn wilkinson_shift(h: &Array2<Complex<f64>>, p: usize) -> Complex<f64> {
        let a = h[[p - 2, p - 2]];
        let b = h[[p - 2, p - 1]];
        let c = h[[p - 1, p - 2]];
        let d = h[[p - 1, p - 1]];
        let tr = a + d;
        let det = a * d - b * c;
        let disc = (tr * tr - Complex::new(4.0, 0.0) * det).sqrt();
        let half = Complex::new(0.5, 0.0);
        let mu1 = (tr + disc) * half;
        let mu2 = (tr - disc) * half;
        if (mu1 - d).norm() < (mu2 - d).norm() {
            mu1
        } else {
            mu2
        }
    }

    /// One implicit single-shift QR sweep on the unreduced Hessenberg block at rows/columns
    /// `lo..p` of `h`.
    ///
    /// Every step applies a unitary Givens rotation `G = [c̄ s̄; -s c]` as the similarity
    /// `h <- G·h·Gᴴ` and accumulates `q <- q·Gᴴ`.
    /// - The first rotation is built from the shifted first column `(h[lo,lo] - shift, h[lo+1,lo])`.
    /// - Each later rotation annihilates the bulge at `h[k+1, k-1]` created by the previous step,
    ///   chasing it off the bottom of the block.
    ///
    /// Rows above `lo` and columns right of `p` are updated too, so `h` converges to a full
    /// Schur form.
    fn qr_sweep(
        h: &mut Array2<Complex<f64>>,
        q: &mut Array2<Complex<f64>>,
        lo: usize,
        p: usize,
        shift: Complex<f64>,
    ) {
        let n = h.nrows();
        for k in lo..p - 1 {
            let (a, b) = if k == lo {
                (h[[lo, lo]] - shift, h[[lo + 1, lo]])
            } else {
                (h[[k, k - 1]], h[[k + 1, k - 1]])
            };
            let r = a.norm().hypot(b.norm());
            if r < 1e-300 {
                continue;
            }
            let c = a / r;
            let s = b / r;

            // Rows k, k+1. Columns left of `col_start` are zero in both rows.
            let col_start = if k == lo { lo } else { k - 1 };
            for j in col_start..n {
                let t1 = c.conj() * h[[k, j]] + s.conj() * h[[k + 1, j]];
                let t2 = -s * h[[k, j]] + c * h[[k + 1, j]];
                h[[k, j]] = t1;
                h[[k + 1, j]] = t2;
            }
            if k > lo {
                // This rotation was chosen to annihilate the bulge; store an exact zero.
                h[[k + 1, k - 1]] = Complex::new(0.0, 0.0);
            }

            // Columns k, k+1. Row k+2 must be included: its subdiagonal entry h[k+2, k+1]
            // spills into column k and becomes the next bulge. Rows at or below `p` are zero
            // in these columns.
            let row_end = (k + 3).min(p);
            for i in 0..row_end {
                let t1 = h[[i, k]] * c + h[[i, k + 1]] * s;
                let t2 = h[[i, k]] * (-s.conj()) + h[[i, k + 1]] * c.conj();
                h[[i, k]] = t1;
                h[[i, k + 1]] = t2;
            }

            for i in 0..n {
                let t1 = q[[i, k]] * c + q[[i, k + 1]] * s;
                let t2 = q[[i, k]] * (-s.conj()) + q[[i, k + 1]] * c.conj();
                q[[i, k]] = t1;
                q[[i, k + 1]] = t2;
            }
        }
    }

    /// Drives the upper Hessenberg matrix `h` to upper-triangular (Schur) form with QR sweeps,
    /// deflating converged eigenvalues from the bottom. The transformations are accumulated
    /// into `q`.
    ///
    /// If the trailing entry fails to converge within `MAX_ITER` sweeps, its subdiagonal is
    /// dropped and the current diagonal value is accepted.
    fn hessenberg_to_schur(h: &mut Array2<Complex<f64>>, q: &mut Array2<Complex<f64>>) {
        const MAX_ITER: usize = 30;
        // Every this many sweeps an exceptional shift replaces the Wilkinson shift.
        const EXCEPTIONAL_SHIFT_PERIOD: usize = 10;
        let zero = Complex::new(0.0, 0.0);

        let mut p = h.nrows();
        let mut iter = 0usize;
        while p > 1 {
            // Start `lo` of the unreduced block that ends at row p-1.
            let mut lo = 0;
            for l in (1..p).rev() {
                if subdiagonal_negligible(h, l) {
                    h[[l, l - 1]] = zero;
                    lo = l;
                    break;
                }
            }

            if lo == p - 1 {
                // 1×1 block: h[p-1, p-1] has converged.
                p -= 1;
                iter = 0;
                continue;
            }
            if iter == MAX_ITER {
                // Give up on this eigenvalue; decouple it so later sweeps remain a consistent
                // similarity transform of the (slightly perturbed) matrix.
                h[[p - 1, p - 2]] = zero;
                p -= 1;
                iter = 0;
                continue;
            }

            iter += 1;
            let shift = if iter % EXCEPTIONAL_SHIFT_PERIOD == 0 {
                // Ad-hoc exceptional shift (in the spirit of LAPACK's ZLAHQR). It breaks cycles
                // where the Wilkinson shift makes no progress, e.g. for permutation matrices.
                h[[p - 1, p - 1]] + Complex::new(0.75 * h[[p - 1, p - 2]].norm(), 0.0)
            } else {
                wilkinson_shift(h, p)
            };
            qr_sweep(h, q, lo, p, shift);
        }
    }

    /// Unit-norm eigenvectors of the upper-triangular matrix `t`, computed by back-substitution.
    ///
    /// Column `k` solves `(t - t[k,k]·I)·v = 0` with `v[k] = 1` and `v[j] = 0` for `j > k`.
    /// Components whose pivot `t[j,j] - t[k,k]` has magnitude at most `1e-14` are set to zero.
    fn triangular_eigenvectors(t: &Array2<Complex<f64>>) -> Array2<Complex<f64>> {
        let n = t.nrows();
        let mut vecs = Array2::<Complex<f64>>::zeros((n, n));
        for k in 0..n {
            let lambda = t[[k, k]];
            let mut v = vec![Complex::new(0.0, 0.0); n];
            v[k] = Complex::new(1.0, 0.0);

            for row in (0..k).rev() {
                let sum: Complex<f64> = (row + 1..=k).map(|col| t[[row, col]] * v[col]).sum();
                let pivot = t[[row, row]] - lambda;
                v[row] = if pivot.norm() > 1e-14 {
                    -sum / pivot
                } else {
                    Complex::new(0.0, 0.0)
                };
            }

            let norm = v.iter().map(|vi| vi.norm_sqr()).sum::<f64>().sqrt();
            if norm > 1e-300 {
                for vi in &mut v {
                    *vi /= norm;
                }
            } else {
                v[k] = Complex::new(1.0, 0.0);
            }

            for (row, vi) in v.into_iter().enumerate() {
                vecs[[row, k]] = vi;
            }
        }
        vecs
    }

    /// Cholesky factorisation of symmetric / Hermitian positive-definite matrices.
    pub trait Cholesky {
        /// Element type of the factor.
        type Elem;

        /// Returns the requested triangular factor.
        ///
        /// - `UPLO::Lower`: the factor `L` produced by the underlying routine (`A = L·Lᴴ`).
        /// - `UPLO::Upper`: its (conjugate) transpose `U = Lᴴ` (`A = Uᴴ·U`).
        ///
        /// Which triangle of the input is read is determined by the underlying routine.
        ///
        /// # Errors
        /// Propagates any [`LapackError`] reported by the underlying routine.
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<Self::Elem>, LapackError>;
    }

    impl Cholesky for Array2<f64> {
        type Elem = f64;

        #[inline]
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<f64>, LapackError> {
            let result = cholesky_ndarray(self)?;
            let l = result.l;
            Ok(match uplo {
                UPLO::Lower => l,
                UPLO::Upper => {
                    Array2::from_shape_fn((l.ncols(), l.nrows()), |(i, j)| l[[j, i]])
                }
            })
        }
    }

    impl Cholesky for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        #[inline]
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<Complex<f64>>, LapackError> {
            let result = cholesky_hermitian_ndarray(self)?;
            let l = result.l;
            Ok(match uplo {
                UPLO::Lower => l,
                UPLO::Upper => {
                    Array2::from_shape_fn((l.ncols(), l.nrows()), |(i, j)| l[[j, i]].conj())
                }
            })
        }
    }
}