1pub use ::ndarray::*;
33
34#[cfg(feature = "random")]
41pub use ndarray_rand::{rand_distr as distributions, RandomExt, SamplingStrategy};
42
43#[cfg(feature = "array_stats")]
48pub use ndarray_stats::{
49 errors as stats_errors, interpolate, CorrelationExt, DeviationExt, MaybeNan, QuantileExt,
50 Sort1dExt, SummaryStatisticsExt,
51};
52
53#[cfg(feature = "array_io")]
54pub use ndarray_npy::{
55 NpzReader, NpzWriter, ReadNpyExt, ReadNpzError, ViewMutNpyExt, ViewNpyExt, WriteNpyError,
56 WriteNpyExt,
57};
58
59pub mod utils {
65 use super::*;
66
67 pub fn eye<A>(n: usize) -> Array2<A>
69 where
70 A: Clone + num_traits::Zero + num_traits::One,
71 {
72 let mut arr = Array2::zeros((n, n));
73 for i in 0..n {
74 arr[[i, i]] = A::one();
75 }
76 arr
77 }
78
79 pub fn diag<A>(v: &Array1<A>) -> Array2<A>
81 where
82 A: Clone + num_traits::Zero,
83 {
84 let n = v.len();
85 let mut arr = Array2::zeros((n, n));
86 for i in 0..n {
87 arr[[i, i]] = v[i].clone();
88 }
89 arr
90 }
91
92 pub fn allclose<A, D>(
94 a: &ArrayBase<impl Data<Elem = A>, D>,
95 b: &ArrayBase<impl Data<Elem = A>, D>,
96 rtol: A,
97 atol: A,
98 ) -> bool
99 where
100 A: PartialOrd
101 + std::ops::Sub<Output = A>
102 + std::ops::Mul<Output = A>
103 + std::ops::Add<Output = A>
104 + Clone,
105 D: Dimension,
106 {
107 if a.shape() != b.shape() {
108 return false;
109 }
110
111 a.iter().zip(b.iter()).all(|(a_val, b_val)| {
112 let diff = if a_val > b_val {
113 a_val.clone() - b_val.clone()
114 } else {
115 b_val.clone() - a_val.clone()
116 };
117
118 let threshold = atol.clone()
119 + rtol.clone()
120 * (if a_val > b_val {
121 a_val.clone()
122 } else {
123 b_val.clone()
124 });
125
126 diff <= threshold
127 })
128 }
129
130 pub fn concatenate<A, D>(
132 axis: Axis,
133 arrays: &[ArrayView<A, D>],
134 ) -> Result<Array<A, D>, ShapeError>
135 where
136 A: Clone,
137 D: Dimension + RemoveAxis,
138 {
139 ndarray::concatenate(axis, arrays)
140 }
141
142 pub fn stack<A, D>(
144 axis: Axis,
145 arrays: &[ArrayView<A, D>],
146 ) -> Result<Array<A, D::Larger>, ShapeError>
147 where
148 A: Clone,
149 D: Dimension,
150 D::Larger: RemoveAxis,
151 {
152 ndarray::stack(axis, arrays)
153 }
154}
155
156pub mod compat {
163 pub use super::*;
164 use crate::numeric::{Float, FromPrimitive};
165
166 pub type DynArray<T> = ArrayD<T>;
168 pub type Matrix<T> = Array2<T>;
169 pub type Vector<T> = Array1<T>;
170 pub type Tensor3<T> = Array3<T>;
171 pub type Tensor4<T> = Array4<T>;
172
173 pub trait ArrayStatCompat<T> {
196 fn mean_or(&self, default: T) -> T;
202
203 fn var_or(&self, ddof: T, default: T) -> T;
205
206 fn std_or(&self, ddof: T, default: T) -> T;
208 }
209
210 impl<T, S, D> ArrayStatCompat<T> for ArrayBase<S, D>
211 where
212 T: Float + FromPrimitive,
213 S: Data<Elem = T>,
214 D: Dimension,
215 {
216 fn mean_or(&self, default: T) -> T {
217 self.mean().unwrap_or(default)
219 }
220
221 fn var_or(&self, ddof: T, default: T) -> T {
222 let v = self.var(ddof);
224 if v.is_nan() {
225 default
226 } else {
227 v
228 }
229 }
230
231 fn std_or(&self, ddof: T, default: T) -> T {
232 let s = self.std(ddof);
234 if s.is_nan() {
235 default
236 } else {
237 s
238 }
239 }
240 }
241
242 pub use crate::ndarray_ext::{
244 broadcast_1d_to_2d,
245 broadcast_apply,
246 fancy_index_2d,
247 indexing,
249 is_broadcast_compatible,
250 manipulation,
251 mask_select,
252 matrix,
253 reshape_2d,
254 split_2d,
255 stack_2d,
256 stats,
257 take_2d,
258 transpose_2d,
259 where_condition,
260 };
261}
262
263pub mod prelude {
269 pub use super::{
270 arr1,
271 arr2,
272 array,
274 azip,
275 concatenate,
277 s,
278 stack,
279
280 stack as stack_fn,
281 Array,
283 Array0,
284 Array1,
285 Array2,
286 Array3,
287 ArrayD,
288 ArrayView,
289 ArrayView1,
290 ArrayView2,
291 ArrayViewMut,
292
293 Axis,
295 Dimension,
297 Ix1,
298 Ix2,
299 Ix3,
300 IxDyn,
301 ScalarOperand,
302 ShapeBuilder,
303
304 Zip,
305 };
306
307 #[cfg(feature = "random")]
308 pub use super::RandomExt;
309
310 pub type Matrix<T> = super::Array2<T>;
312 pub type Vector<T> = super::Array1<T>;
313}
314
#[cfg(test)]
pub mod examples {
    use super::*;

    /// Exercises construction, slicing, arithmetic, reduction and broadcasting
    /// through the re-exported API.
    #[test]
    fn test_complete_functionality() {
        let matrix = array![[1., 2.], [3., 4.]];
        assert_eq!(matrix.shape(), &[2, 2]);

        // First column of every row.
        let first_column = matrix.slice(s![.., 0]);
        assert_eq!(first_column.len(), 2);

        // Element-wise sum of two borrowed arrays.
        let doubled = &matrix + &matrix;
        assert_eq!(doubled[[0, 0]], 2.);

        // Reducing along axis 0 leaves one value per column.
        let column_sums = matrix.sum_axis(Axis(0));
        assert_eq!(column_sums.len(), 2);

        // A 1-D array broadcasts across every row.
        let row = array![1., 2.];
        let shifted = &matrix + &row;
        assert_eq!(shifted[[0, 0]], 2.);
    }
}
374
375pub mod migration_guide {
380 }
420
421pub use compat::ArrayStatCompat;
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_macro_available() {
        let grid = array![[1, 2], [3, 4]];
        assert_eq!(grid.shape(), &[2, 2]);
        assert_eq!(grid[[0, 0]], 1);
    }

    #[test]
    fn test_s_macro_available() {
        let grid = array![[1, 2, 3], [4, 5, 6]];
        let tail_columns = grid.slice(s![.., 1..]);
        assert_eq!(tail_columns.shape(), &[2, 2]);
    }

    #[test]
    fn test_axis_operations() {
        let grid = array![[1., 2.], [3., 4.]];
        assert_eq!(grid.sum_axis(Axis(0)), array![4., 6.]);
    }

    #[test]
    fn test_views_and_iteration() {
        let mut grid = array![[1, 2], [3, 4]];

        // Shared view: every element is positive.
        {
            let shared: ArrayView<_, _> = grid.view();
            assert!(shared.iter().all(|&value| value > 0));
        }

        // Exclusive view: double every element in place.
        {
            let mut exclusive: ArrayViewMut<_, _> = grid.view_mut();
            exclusive.iter_mut().for_each(|value| *value *= 2);
        }

        assert_eq!(grid[[0, 0]], 2);
    }

    #[test]
    fn test_concatenate_and_stack() {
        let top = array![[1, 2], [3, 4]];
        let bottom = array![[5, 6], [7, 8]];
        let parts = [top.view(), bottom.view()];

        let joined = concatenate(Axis(0), &parts).expect("Operation failed");
        assert_eq!(joined.shape(), &[4, 2]);

        let layered = crate::ndarray::stack(Axis(0), &parts).expect("Operation failed");
        assert_eq!(layered.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_zip_operations() {
        let lhs = array![1, 2, 3];
        let rhs = array![4, 5, 6];
        let mut out = array![0, 0, 0];

        azip!((x in &lhs, y in &rhs, z in &mut out) {
            *z = x + y;
        });

        assert_eq!(out, array![5, 7, 9]);
    }

    #[test]
    fn test_array_stat_compat() {
        use compat::ArrayStatCompat;

        let samples = array![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(samples.mean_or(0.0), 3.0);

        let nothing = Array1::<f64>::from(vec![]);
        assert_eq!(nothing.mean_or(0.0), 0.0);

        assert!(samples.var_or(1.0, 0.0) > 0.0);
        assert!(samples.std_or(1.0, 0.0) > 0.0);
    }
}
521
/// Linear-algebra traits named after the `ndarray-linalg` crate's API (`Solve`, `SVD`,
/// `Eigh`, `Norm`, `QR`, `Eig`, `Cholesky`), implemented for `f64` and `Complex<f64>`
/// arrays on top of this crate's `linalg` routines and `oxiblas_ndarray`.
#[cfg(feature = "linalg")]
pub mod ndarray_linalg {
    use crate::linalg::prelude::*;
    use crate::ndarray::*;
    use num_complex::Complex;

    use oxiblas_ndarray::lapack::{
        cholesky_hermitian_ndarray, eig_hermitian_ndarray, qr_complex_ndarray, svd_complex_ndarray,
    };

    pub use crate::linalg::{LapackError, LapackResult};

    /// Selects a triangle of a matrix.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum UPLO {
        /// The upper triangle.
        Upper,
        /// The lower triangle.
        Lower,
    }

    /// Solves the linear system `self * x = b` for `x`.
    pub trait Solve<A> {
        /// Returns the solution `x`; errors come from the backend solver.
        fn solve_into(&self, b: &Array1<A>) -> Result<Array1<A>, LapackError>;
    }

    impl Solve<f64> for Array2<f64> {
        #[inline]
        fn solve_into(&self, b: &Array1<f64>) -> Result<Array1<f64>, LapackError> {
            solve_ndarray(self, b)
        }
    }

    impl Solve<Complex<f64>> for Array2<Complex<f64>> {
        #[inline]
        fn solve_into(
            &self,
            b: &Array1<Complex<f64>>,
        ) -> Result<Array1<Complex<f64>>, LapackError> {
            solve_ndarray(self, b)
        }
    }

    /// Singular value decomposition.
    pub trait SVD {
        /// Element type of the singular-vector matrices.
        type Elem;
        /// Type of the singular values.
        type Real;

        /// Returns the `(u, s, vt)` factors reported by the backend.
        ///
        /// NOTE: `compute_u` / `compute_vt` are currently ignored; both factors are
        /// always computed and returned.
        fn svd(
            &self,
            compute_u: bool,
            compute_vt: bool,
        ) -> Result<(Array2<Self::Elem>, Array1<Self::Real>, Array2<Self::Elem>), LapackError>;
    }

    impl SVD for Array2<f64> {
        type Elem = f64;
        type Real = f64;

        #[inline]
        fn svd(
            &self,
            _compute_u: bool,
            _compute_vt: bool,
        ) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>), LapackError> {
            let result = svd_ndarray(self)?;
            Ok((result.u, result.s, result.vt))
        }
    }

    impl SVD for Array2<Complex<f64>> {
        type Elem = Complex<f64>;
        type Real = f64;

        #[inline]
        fn svd(
            &self,
            _compute_u: bool,
            _compute_vt: bool,
        ) -> Result<(Array2<Complex<f64>>, Array1<f64>, Array2<Complex<f64>>), LapackError>
        {
            let result = svd_complex_ndarray(self)?;
            Ok((result.u, result.s, result.vt))
        }
    }

    /// Eigen-decomposition of a symmetric (real) or Hermitian (complex) matrix.
    pub trait Eigh {
        /// Element type of the eigenvectors.
        type Elem;
        /// Type of the (real) eigenvalues.
        type Real;

        /// Returns `(eigenvalues, eigenvectors)`.
        ///
        /// NOTE: `uplo` is currently ignored; the whole matrix is passed to the backend,
        /// so the input should be exactly symmetric/Hermitian.
        fn eigh(&self, uplo: UPLO)
            -> Result<(Array1<Self::Real>, Array2<Self::Elem>), LapackError>;
    }

    impl Eigh for Array2<f64> {
        type Elem = f64;
        type Real = f64;

        #[inline]
        fn eigh(&self, _uplo: UPLO) -> Result<(Array1<f64>, Array2<f64>), LapackError> {
            let result = eig_symmetric(self)?;
            Ok((result.eigenvalues, result.eigenvectors))
        }
    }

    impl Eigh for Array2<Complex<f64>> {
        type Elem = Complex<f64>;
        type Real = f64;

        #[inline]
        fn eigh(&self, _uplo: UPLO) -> Result<(Array1<f64>, Array2<Complex<f64>>), LapackError> {
            eig_hermitian_ndarray(self)
        }
    }

    /// Entry-wise L2 norm: the Euclidean norm for vectors and the Frobenius norm for
    /// matrices. The current implementations never fail; `Result` keeps the signature
    /// uniform with the other traits.
    pub trait Norm {
        /// Type of the norm value.
        type Real;

        /// Square root of the sum of squared element magnitudes.
        fn norm_l2(&self) -> Result<Self::Real, LapackError>;
    }

    impl Norm for Array2<f64> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x * x).sum();
            Ok(sum_sq.sqrt())
        }
    }

    impl Norm for Array2<Complex<f64>> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x.norm_sqr()).sum();
            Ok(sum_sq.sqrt())
        }
    }

    impl Norm for Array1<f64> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x * x).sum();
            Ok(sum_sq.sqrt())
        }
    }

    impl Norm for Array1<Complex<f64>> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x.norm_sqr()).sum();
            Ok(sum_sq.sqrt())
        }
    }

    /// QR decomposition.
    pub trait QR {
        /// Element type of the factors.
        type Elem;

        /// Returns the `(q, r)` factors reported by the backend.
        fn qr(&self) -> Result<(Array2<Self::Elem>, Array2<Self::Elem>), LapackError>;
    }

    impl QR for Array2<f64> {
        type Elem = f64;

        #[inline]
        fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), LapackError> {
            let result = qr_ndarray(self)?;
            Ok((result.q, result.r))
        }
    }

    impl QR for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        #[inline]
        fn qr(&self) -> Result<(Array2<Complex<f64>>, Array2<Complex<f64>>), LapackError> {
            let result = qr_complex_ndarray(self)?;
            Ok((result.q, result.r))
        }
    }

    /// Eigen-decomposition of a general square matrix.
    pub trait Eig {
        /// Element type of the eigenvalues and eigenvectors.
        type Elem;

        /// Returns `(eigenvalues, eigenvectors)`; eigenvector `j` is column `j`.
        fn eig(&self) -> Result<(Array1<Self::Elem>, Array2<Self::Elem>), LapackError>;
    }

    impl Eig for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        /// Pure-Rust eigen-decomposition via a complex Schur form:
        /// Householder reduction to Hessenberg form, single-shift implicit QR sweeps
        /// to upper-triangular `T = Qᴴ A Q`, then back-substitution for the eigenvectors
        /// of `T`, mapped back through `Q`. Eigenvectors have unit Euclidean norm.
        ///
        /// # Errors
        /// `LapackError::DimensionMismatch` if the matrix is not square.
        fn eig(&self) -> Result<(Array1<Complex<f64>>, Array2<Complex<f64>>), LapackError> {
            let (m, n) = self.dim();
            if m != n {
                return Err(LapackError::DimensionMismatch(
                    "Matrix must be square for eigendecomposition".to_string(),
                ));
            }
            if n == 0 {
                return Ok((
                    Array1::<Complex<f64>>::zeros(0),
                    Array2::<Complex<f64>>::zeros((0, 0)),
                ));
            }
            if n == 1 {
                let eigenvalue = self[[0, 0]];
                let eigenvector = Array2::from_elem((1, 1), Complex::new(1.0, 0.0));
                return Ok((Array1::from_vec(vec![eigenvalue]), eigenvector));
            }

            let mut h = self.clone();
            let mut q = Array2::<Complex<f64>>::eye(n);
            reduce_to_hessenberg(&mut h, &mut q);
            reduce_to_schur(&mut h, &mut q);

            let eigenvalues: Array1<Complex<f64>> = Array1::from_iter((0..n).map(|i| h[[i, i]]));
            let vecs = triangular_eigenvectors(&h, &eigenvalues);
            Ok((eigenvalues, q.dot(&vecs)))
        }
    }

    /// Maximum implicit-QR sweeps spent on each trailing eigenvalue before it is
    /// accepted as-is.
    const MAX_QR_SWEEPS: usize = 30;

    /// Reduces `h` to upper Hessenberg form in place with Householder reflections
    /// `P = I - 2 v vᴴ`, applying each as the similarity `h <- P h P` and accumulating
    /// `q <- q P`.
    fn reduce_to_hessenberg(h: &mut Array2<Complex<f64>>, q: &mut Array2<Complex<f64>>) {
        let n = h.nrows();
        let two = Complex::new(2.0, 0.0);

        for col in 0..n.saturating_sub(2) {
            let mut x: Vec<Complex<f64>> = (col + 1..n).map(|r| h[[r, col]]).collect();

            let norm_x = x.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
            if norm_x < 1e-300 {
                continue;
            }

            // Add ||x|| along x[0]'s phase so the leading component cannot cancel.
            let phase = if x[0].norm() > 1e-300 {
                x[0] / x[0].norm()
            } else {
                Complex::new(1.0, 0.0)
            };
            x[0] += phase * norm_x;

            let norm_v = x.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
            if norm_v < 1e-300 {
                continue;
            }
            let v: Vec<Complex<f64>> = x.iter().map(|vi| *vi / norm_v).collect();

            // h <- P h on rows col+1..n (columns left of `col` are already zero there).
            for c in col..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| vi.conj() * h[[col + 1 + i, c]])
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    h[[col + 1 + i, c]] -= two * vi * dot;
                }
            }

            // h <- h P on columns col+1..n.
            for r in 0..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| h[[r, col + 1 + i]] * vi)
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    h[[r, col + 1 + i]] -= two * dot * vi.conj();
                }
            }

            // q <- q P.
            for r in 0..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| q[[r, col + 1 + i]] * vi)
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    q[[r, col + 1 + i]] -= two * dot * vi.conj();
                }
            }

            // Entries below the subdiagonal are zero in exact arithmetic; drop rounding noise.
            for r in col + 2..n {
                h[[r, col]] = Complex::new(0.0, 0.0);
            }
        }
    }

    /// Drives the upper-Hessenberg `h` to upper-triangular (Schur) form in place with
    /// single-shift implicit QR sweeps on the active leading `p x p` window, deflating
    /// one eigenvalue from the bottom at a time. Every rotation is accumulated into `q`.
    ///
    /// If a trailing eigenvalue has not converged after `MAX_QR_SWEEPS` sweeps it is
    /// accepted at its current accuracy; no error is reported.
    fn reduce_to_schur(h: &mut Array2<Complex<f64>>, q: &mut Array2<Complex<f64>>) {
        let zero = Complex::new(0.0, 0.0);
        let mut p = h.nrows();

        'outer: while p > 1 {
            // Zero every negligible subdiagonal in the window; deflate immediately when
            // the bottom one is negligible.
            for l in (1..p).rev() {
                if subdiagonal_is_negligible(h, l) {
                    h[[l, l - 1]] = zero;
                    if l == p - 1 {
                        p -= 1;
                        continue 'outer;
                    }
                }
            }

            for _ in 0..MAX_QR_SWEEPS {
                let shift = wilkinson_shift(h, p);
                implicit_qr_sweep(h, q, p, shift);
                if subdiagonal_is_negligible(h, p - 1) {
                    h[[p - 1, p - 2]] = zero;
                    break;
                }
            }
            p -= 1;
        }
    }

    /// `true` when `h[l, l-1]` is negligible relative to its two neighbouring diagonal
    /// entries, or below `sqrt(f64::MIN_POSITIVE)` in absolute terms.
    fn subdiagonal_is_negligible(h: &Array2<Complex<f64>>, l: usize) -> bool {
        let sub = h[[l, l - 1]].norm();
        let diag = h[[l - 1, l - 1]].norm() + h[[l, l]].norm();
        sub <= 1e-14 * diag || sub <= f64::MIN_POSITIVE.sqrt()
    }

    /// Wilkinson shift: the eigenvalue of the trailing 2x2 block of the `p x p` window
    /// that is closer to its bottom-right entry.
    fn wilkinson_shift(h: &Array2<Complex<f64>>, p: usize) -> Complex<f64> {
        let a = h[[p - 2, p - 2]];
        let b = h[[p - 2, p - 1]];
        let c = h[[p - 1, p - 2]];
        let d = h[[p - 1, p - 1]];

        let tr = a + d;
        let det = a * d - b * c;
        let disc = (tr * tr - Complex::new(4.0, 0.0) * det).sqrt();
        let half = Complex::new(0.5, 0.0);
        let mu1 = (tr + disc) * half;
        let mu2 = (tr - disc) * half;

        if (mu1 - d).norm() < (mu2 - d).norm() {
            mu1
        } else {
            mu2
        }
    }

    /// One single-shift implicit QR sweep (bulge chase) over the leading `p x p` window
    /// of the Hessenberg matrix `h`. Each Givens rotation `G` is applied as the unitary
    /// similarity `h <- Gᴴ h G` on the full matrix and accumulated as `q <- q G`.
    fn implicit_qr_sweep(
        h: &mut Array2<Complex<f64>>,
        q: &mut Array2<Complex<f64>>,
        p: usize,
        shift: Complex<f64>,
    ) {
        let n = h.nrows();

        for k in 0..p - 1 {
            // Step 0 introduces the shift and creates a bulge at (2, 0). Every later step
            // picks the rotation that annihilates the bulge at (k + 1, k - 1).
            let (x, y) = if k == 0 {
                (h[[0, 0]] - shift, h[[1, 0]])
            } else {
                (h[[k, k - 1]], h[[k + 1, k - 1]])
            };
            let r = (x.norm_sqr() + y.norm_sqr()).sqrt();
            if r < 1e-300 {
                continue;
            }
            let c = x / r;
            let s = y / r;

            // Rows k, k+1 <- Gᴴ (from the left), Gᴴ = [[c̄, s̄], [-s, c]].
            for j in k.saturating_sub(1)..n {
                let t1 = c.conj() * h[[k, j]] + s.conj() * h[[k + 1, j]];
                let t2 = -s * h[[k, j]] + c * h[[k + 1, j]];
                h[[k, j]] = t1;
                h[[k + 1, j]] = t2;
            }
            if k > 0 {
                // The bulge is annihilated exactly in exact arithmetic; drop the residue.
                h[[k + 1, k - 1]] = Complex::new(0.0, 0.0);
            }

            // Columns k, k+1 <- G (from the right). Row k+2 must be included: that is
            // where the next bulge appears. Rows >= p are zero in these columns.
            for i in 0..(k + 3).min(p) {
                let t1 = h[[i, k]] * c + h[[i, k + 1]] * s;
                let t2 = h[[i, k]] * (-s.conj()) + h[[i, k + 1]] * c.conj();
                h[[i, k]] = t1;
                h[[i, k + 1]] = t2;
            }

            for i in 0..n {
                let t1 = q[[i, k]] * c + q[[i, k + 1]] * s;
                let t2 = q[[i, k]] * (-s.conj()) + q[[i, k + 1]] * c.conj();
                q[[i, k]] = t1;
                q[[i, k + 1]] = t2;
            }
        }
    }

    /// Eigenvectors of the upper-triangular `t` by back-substitution on
    /// `(t - λ I) v = 0` with `v[ei] = 1`, one unit-norm column per eigenvalue.
    /// Only the upper triangle of `t` is read.
    fn triangular_eigenvectors(
        t: &Array2<Complex<f64>>,
        eigenvalues: &Array1<Complex<f64>>,
    ) -> Array2<Complex<f64>> {
        let n = t.nrows();
        let zero = Complex::new(0.0, 0.0);
        let one = Complex::new(1.0, 0.0);
        let mut vecs = Array2::<Complex<f64>>::zeros((n, n));

        for (ei, &lambda) in eigenvalues.iter().enumerate() {
            let mut v = vec![zero; n];
            v[ei] = one;

            for row in (0..ei).rev() {
                let sum: Complex<f64> = (row + 1..=ei).map(|col| t[[row, col]] * v[col]).sum();
                let diag = t[[row, row]] - lambda;
                // A (near-)repeated eigenvalue makes this pivot vanish; that component is
                // set to zero rather than dividing by ~0.
                v[row] = if diag.norm() > 1e-14 { -sum / diag } else { zero };
            }

            let norm = v.iter().map(|vi| vi.norm_sqr()).sum::<f64>().sqrt();
            if norm > 1e-300 {
                for vi in &mut v {
                    *vi /= norm;
                }
            } else {
                v[ei] = one;
            }

            for (dst, &src) in vecs.column_mut(ei).iter_mut().zip(&v) {
                *dst = src;
            }
        }
        vecs
    }

    /// Cholesky factorisation of a symmetric/Hermitian positive-definite matrix.
    pub trait Cholesky {
        /// Element type of the factor.
        type Elem;

        /// Returns `L` with `A = L Lᴴ` for `UPLO::Lower`, or `U = Lᴴ` with `A = Uᴴ U`
        /// for `UPLO::Upper`.
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<Self::Elem>, LapackError>;
    }

    impl Cholesky for Array2<f64> {
        type Elem = f64;

        #[inline]
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<f64>, LapackError> {
            let result = cholesky_ndarray(self)?;
            let l = result.l;
            Ok(match uplo {
                UPLO::Lower => l,
                // A = L Lᵀ = Uᵀ U with U = Lᵀ.
                UPLO::Upper => {
                    let (rows, cols) = l.dim();
                    Array2::from_shape_fn((cols, rows), |(i, j)| l[[j, i]])
                }
            })
        }
    }

    impl Cholesky for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        #[inline]
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<Complex<f64>>, LapackError> {
            let result = cholesky_hermitian_ndarray(self)?;
            let l = result.l;
            Ok(match uplo {
                UPLO::Lower => l,
                // A = L Lᴴ = Uᴴ U with U = Lᴴ (conjugate transpose).
                UPLO::Upper => {
                    let (rows, cols) = l.dim();
                    Array2::from_shape_fn((cols, rows), |(i, j)| l[[j, i]].conj())
                }
            })
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        /// Checks `A v_j = λ_j v_j` and `|v_j| = 1` for every returned eigenpair.
        fn assert_eigenpairs(real: Array2<f64>) {
            let a = real.mapv(|x| Complex::new(x, 0.0));
            let (values, vectors) = Eig::eig(&a).expect("eig failed");
            let av = a.dot(&vectors);
            for ((i, j), &lhs) in av.indexed_iter() {
                let residual = (lhs - vectors[[i, j]] * values[j]).norm();
                assert!(residual < 1e-9, "residual {} at ({}, {})", residual, i, j);
            }
            for column in vectors.columns() {
                let norm_sq: f64 = column.iter().map(|z| z.norm_sqr()).sum();
                assert!((norm_sq - 1.0).abs() < 1e-9);
            }
        }

        #[test]
        fn eig_satisfies_eigen_equation_for_3x3_and_larger() {
            assert_eigenpairs(array![[2.0, 1.0, 0.0], [1.0, 2.0, 1.0], [0.0, 1.0, 2.0]]);
            assert_eigenpairs(array![[4.0, 1.0, 2.0], [0.5, 3.0, 1.0], [2.0, 1.0, 5.0]]);
            assert_eigenpairs(array![
                [2.0, -1.0, 0.0, 3.0],
                [1.0, 0.0, 4.0, 1.0],
                [0.0, 5.0, 1.0, -2.0],
                [3.0, 1.0, -1.0, 1.0]
            ]);
        }

        #[test]
        fn cholesky_upper_is_transpose_of_lower() {
            let a = array![[4.0, 2.0], [2.0, 3.0]];
            let lower = Cholesky::cholesky(&a, UPLO::Lower).expect("cholesky failed");
            let upper = Cholesky::cholesky(&a, UPLO::Upper).expect("cholesky failed");
            assert_eq!(upper, lower.t());
            let rebuilt = upper.t().dot(&upper);
            for (x, y) in rebuilt.iter().zip(a.iter()) {
                assert!((x - y).abs() < 1e-12);
            }
        }
    }
}