1pub use ::ndarray::*;
33
34#[cfg(feature = "random")]
41pub use ndarray_rand::{rand_distr as distributions, RandomExt, SamplingStrategy};
42
43#[cfg(feature = "array_stats")]
48pub use ndarray_stats::{
49 errors as stats_errors, interpolate, CorrelationExt, DeviationExt, MaybeNan, QuantileExt,
50 Sort1dExt, SummaryStatisticsExt,
51};
52
53#[cfg(feature = "array_io")]
54pub use ndarray_npy::{
55 NpzReader, NpzWriter, ReadNpyExt, ReadNpzError, ViewMutNpyExt, ViewNpyExt, WriteNpyError,
56 WriteNpyExt,
57};
58
59pub mod utils {
65 use super::*;
66
67 pub fn eye<A>(n: usize) -> Array2<A>
69 where
70 A: Clone + num_traits::Zero + num_traits::One,
71 {
72 let mut arr = Array2::zeros((n, n));
73 for i in 0..n {
74 arr[[i, i]] = A::one();
75 }
76 arr
77 }
78
79 pub fn diag<A>(v: &Array1<A>) -> Array2<A>
81 where
82 A: Clone + num_traits::Zero,
83 {
84 let n = v.len();
85 let mut arr = Array2::zeros((n, n));
86 for i in 0..n {
87 arr[[i, i]] = v[i].clone();
88 }
89 arr
90 }
91
92 pub fn allclose<A, D>(
94 a: &ArrayBase<impl Data<Elem = A>, D>,
95 b: &ArrayBase<impl Data<Elem = A>, D>,
96 rtol: A,
97 atol: A,
98 ) -> bool
99 where
100 A: PartialOrd
101 + std::ops::Sub<Output = A>
102 + std::ops::Mul<Output = A>
103 + std::ops::Add<Output = A>
104 + Clone,
105 D: Dimension,
106 {
107 if a.shape() != b.shape() {
108 return false;
109 }
110
111 a.iter().zip(b.iter()).all(|(a_val, b_val)| {
112 let diff = if a_val > b_val {
113 a_val.clone() - b_val.clone()
114 } else {
115 b_val.clone() - a_val.clone()
116 };
117
118 let threshold = atol.clone()
119 + rtol.clone()
120 * (if a_val > b_val {
121 a_val.clone()
122 } else {
123 b_val.clone()
124 });
125
126 diff <= threshold
127 })
128 }
129
130 pub fn concatenate<A, D>(
132 axis: Axis,
133 arrays: &[ArrayView<A, D>],
134 ) -> Result<Array<A, D>, ShapeError>
135 where
136 A: Clone,
137 D: Dimension + RemoveAxis,
138 {
139 ndarray::concatenate(axis, arrays)
140 }
141
142 pub fn stack<A, D>(
144 axis: Axis,
145 arrays: &[ArrayView<A, D>],
146 ) -> Result<Array<A, D::Larger>, ShapeError>
147 where
148 A: Clone,
149 D: Dimension,
150 D::Larger: RemoveAxis,
151 {
152 ndarray::stack(axis, arrays)
153 }
154}
155
156pub mod compat {
163 pub use super::*;
164 use crate::numeric::{Float, FromPrimitive};
165
166 pub type DynArray<T> = ArrayD<T>;
168 pub type Matrix<T> = Array2<T>;
169 pub type Vector<T> = Array1<T>;
170 pub type Tensor3<T> = Array3<T>;
171 pub type Tensor4<T> = Array4<T>;
172
173 pub trait ArrayStatCompat<T> {
196 fn mean_or(&self, default: T) -> T;
202
203 fn var_or(&self, ddof: T, default: T) -> T;
205
206 fn std_or(&self, ddof: T, default: T) -> T;
208 }
209
210 impl<T, S, D> ArrayStatCompat<T> for ArrayBase<S, D>
211 where
212 T: Float + FromPrimitive,
213 S: Data<Elem = T>,
214 D: Dimension,
215 {
216 fn mean_or(&self, default: T) -> T {
217 self.mean().unwrap_or(default)
219 }
220
221 fn var_or(&self, ddof: T, default: T) -> T {
222 let v = self.var(ddof);
224 if v.is_nan() {
225 default
226 } else {
227 v
228 }
229 }
230
231 fn std_or(&self, ddof: T, default: T) -> T {
232 let s = self.std(ddof);
234 if s.is_nan() {
235 default
236 } else {
237 s
238 }
239 }
240 }
241
242 pub use crate::ndarray_ext::{
244 broadcast_1d_to_2d,
245 broadcast_apply,
246 fancy_index_2d,
247 indexing,
249 is_broadcast_compatible,
250 manipulation,
251 mask_select,
252 matrix,
253 reshape_2d,
254 split_2d,
255 stack_2d,
256 stats,
257 take_2d,
258 transpose_2d,
259 where_condition,
260 };
261}
262
263pub mod prelude {
269 pub use super::{
270 arr1,
271 arr2,
272 array,
274 azip,
275 concatenate,
277 s,
278 stack,
279
280 stack as stack_fn,
281 Array,
283 Array0,
284 Array1,
285 Array2,
286 Array3,
287 ArrayD,
288 ArrayView,
289 ArrayView1,
290 ArrayView2,
291 ArrayViewMut,
292
293 Axis,
295 Dimension,
297 Ix1,
298 Ix2,
299 Ix3,
300 IxDyn,
301 ScalarOperand,
302 ShapeBuilder,
303
304 Zip,
305 };
306
307 #[cfg(feature = "random")]
308 pub use super::RandomExt;
309
310 pub type Matrix<T> = super::Array2<T>;
312 pub type Vector<T> = super::Array1<T>;
313}
314
315#[cfg(test)]
320pub mod examples {
321 use super::*;
324
/// Walks through construction, slicing, element-wise arithmetic, axis
/// reduction and broadcasting using only the re-exported items.
#[test]
fn test_complete_functionality() {
    // Construction via the `array!` macro.
    let matrix = array![[1., 2.], [3., 4.]];
    assert_eq!(matrix.shape(), &[2, 2]);

    // Slicing via the `s!` macro: first column.
    let first_column = matrix.slice(s![.., 0]);
    assert_eq!(first_column.len(), 2);

    // Element-wise arithmetic on references.
    let doubled = &matrix + &matrix;
    assert_eq!(doubled[[0, 0]], 2.);

    // Reduction along an axis.
    let column_sums = matrix.sum_axis(Axis(0));
    assert_eq!(column_sums.len(), 2);

    // Broadcasting a 1-D array against a 2-D array.
    let row = array![1., 2.];
    let shifted = &matrix + &row;
    assert_eq!(shifted[[0, 0]], 2.);
}
373}
374
/// Module without items.
///
/// NOTE(review): presumably kept only as an anchor for migration
/// documentation — confirm before removing.
pub mod migration_guide {}
420
421pub use compat::ArrayStatCompat;
423
424#[cfg(test)]
425mod tests {
426 use super::*;
427
/// The `array!` macro must be reachable through the glob re-export.
#[test]
fn test_array_macro_available() {
    let matrix = array![[1, 2], [3, 4]];

    assert_eq!(matrix.shape(), &[2, 2]);
    assert_eq!(matrix[[0, 0]], 1);
}
434
/// The `s!` slicing macro must be reachable through the glob re-export.
#[test]
fn test_s_macro_available() {
    let matrix = array![[1, 2, 3], [4, 5, 6]];

    // Keep every row, drop the first column.
    let trailing_columns = matrix.slice(s![.., 1..]);

    assert_eq!(trailing_columns.shape(), &[2, 2]);
}
441
/// Summing along `Axis(0)` adds the rows together column by column.
#[test]
fn test_axis_operations() {
    let matrix = array![[1., 2.], [3., 4.]];

    let column_sums = matrix.sum_axis(Axis(0));

    assert_eq!(column_sums, array![4., 6.]);
}
448
/// Read-only and mutable views must both be usable through the re-exports.
#[test]
fn test_views_and_iteration() {
    let mut grid = array![[1, 2], [3, 4]];

    // Immutable view: every element is visited and is positive.
    {
        let read_only: ArrayView<_, _> = grid.view();
        for element in read_only.iter() {
            assert!(*element > 0);
        }
    }

    // Mutable view: writes go through to the owning array.
    {
        let mut writable: ArrayViewMut<_, _> = grid.view_mut();
        writable.iter_mut().for_each(|element| *element *= 2);
    }

    assert_eq!(grid[[0, 0]], 2);
}
471
/// `concatenate` grows an existing axis, while `stack` introduces a new one.
#[test]
fn test_concatenate_and_stack() {
    let first = array![[1, 2], [3, 4]];
    let second = array![[5, 6], [7, 8]];

    // Two 2x2 inputs joined along axis 0 give a 4x2 result.
    let joined = concatenate(Axis(0), &[first.view(), second.view()]).expect("Operation failed");
    assert_eq!(joined.shape(), &[4, 2]);

    // Two 2x2 inputs stacked along a new axis 0 give a 2x2x2 result.
    let layered = crate::ndarray::stack(Axis(0), &[first.view(), second.view()])
        .expect("Operation failed");
    assert_eq!(layered.shape(), &[2, 2, 2]);
}
486
/// The `azip!` macro performs a lock-step element-wise update.
#[test]
fn test_zip_operations() {
    let lhs = array![1, 2, 3];
    let rhs = array![4, 5, 6];
    let mut total = array![0, 0, 0];

    azip!((x in &lhs, y in &rhs, out in &mut total) {
        *out = x + y;
    });

    assert_eq!(total, array![5, 7, 9]);
}
499
/// Exercises the fallback-returning statistics of `ArrayStatCompat`.
#[test]
fn test_array_stat_compat() {
    use compat::ArrayStatCompat;

    // Mean of a non-empty array is the real mean, not the fallback.
    let samples = array![1.0, 2.0, 3.0, 4.0, 5.0];
    assert_eq!(samples.mean_or(0.0), 3.0);

    // Mean of an empty array falls back to the supplied default.
    let no_samples = Array1::<f64>::from(vec![]);
    assert_eq!(no_samples.mean_or(0.0), 0.0);

    // Variance and standard deviation (ddof = 1) of spread-out data are positive.
    let samples = array![1.0, 2.0, 3.0, 4.0, 5.0];
    let variance = samples.var_or(1.0, 0.0);
    assert!(variance > 0.0);

    let deviation = samples.std_or(1.0, 0.0);
    assert!(deviation > 0.0);
}
520}
521
522#[cfg(feature = "linalg")]
530pub mod ndarray_linalg {
531 use crate::linalg::prelude::*;
532 use crate::ndarray::*;
533 use num_complex::Complex;
534
535 use oxiblas_ndarray::lapack::{
537 cholesky_hermitian_ndarray, eig_hermitian_ndarray, qr_complex_ndarray, svd_complex_ndarray,
538 };
539
540 pub use crate::linalg::{LapackError, LapackResult};
542
/// Triangle selector accepted by [`Eigh::eigh`] and [`Cholesky::cholesky`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UPLO {
    /// The upper triangle.
    Upper,
    /// The lower triangle.
    Lower,
}
549
550 pub trait Solve<A> {
552 fn solve_into(&self, b: &Array1<A>) -> Result<Array1<A>, LapackError>;
553 }
554
555 impl Solve<f64> for Array2<f64> {
556 #[inline]
557 fn solve_into(&self, b: &Array1<f64>) -> Result<Array1<f64>, LapackError> {
558 solve_ndarray(self, b)
559 }
560 }
561
562 impl Solve<Complex<f64>> for Array2<Complex<f64>> {
563 #[inline]
564 fn solve_into(
565 &self,
566 b: &Array1<Complex<f64>>,
567 ) -> Result<Array1<Complex<f64>>, LapackError> {
568 solve_ndarray(self, b)
569 }
570 }
571
572 pub trait SVD {
574 type Elem;
575 type Real;
576
577 fn svd(
578 &self,
579 compute_u: bool,
580 compute_vt: bool,
581 ) -> Result<(Array2<Self::Elem>, Array1<Self::Real>, Array2<Self::Elem>), LapackError>;
582 }
583
584 impl SVD for Array2<f64> {
585 type Elem = f64;
586 type Real = f64;
587
588 #[inline]
589 fn svd(
590 &self,
591 _compute_u: bool,
592 _compute_vt: bool,
593 ) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>), LapackError> {
594 let result = svd_ndarray(self)?;
595 Ok((result.u, result.s, result.vt))
596 }
597 }
598
599 impl SVD for Array2<Complex<f64>> {
600 type Elem = Complex<f64>;
601 type Real = f64;
602
603 #[inline]
604 fn svd(
605 &self,
606 _compute_u: bool,
607 _compute_vt: bool,
608 ) -> Result<(Array2<Complex<f64>>, Array1<f64>, Array2<Complex<f64>>), LapackError>
609 {
610 let result = svd_complex_ndarray(self)?;
611 Ok((result.u, result.s, result.vt))
612 }
613 }
614
615 pub trait Eigh {
617 type Elem;
618 type Real;
619
620 fn eigh(&self, uplo: UPLO)
621 -> Result<(Array1<Self::Real>, Array2<Self::Elem>), LapackError>;
622 }
623
624 impl Eigh for Array2<f64> {
625 type Elem = f64;
626 type Real = f64;
627
628 #[inline]
629 fn eigh(&self, _uplo: UPLO) -> Result<(Array1<f64>, Array2<f64>), LapackError> {
630 let result = eig_symmetric(self)?;
631 Ok((result.eigenvalues, result.eigenvectors))
632 }
633 }
634
635 impl Eigh for Array2<Complex<f64>> {
636 type Elem = Complex<f64>;
637 type Real = f64;
638
639 #[inline]
640 fn eigh(&self, _uplo: UPLO) -> Result<(Array1<f64>, Array2<Complex<f64>>), LapackError> {
641 eig_hermitian_ndarray(self)
642 }
643 }
644
645 pub trait Norm {
647 type Real;
648
649 fn norm_l2(&self) -> Result<Self::Real, LapackError>;
650 }
651
652 impl Norm for Array2<f64> {
653 type Real = f64;
654
655 #[inline]
656 fn norm_l2(&self) -> Result<f64, LapackError> {
657 let sum_sq: f64 = self.iter().map(|x| x * x).sum();
658 Ok(sum_sq.sqrt())
659 }
660 }
661
662 impl Norm for Array2<Complex<f64>> {
663 type Real = f64;
664
665 #[inline]
666 fn norm_l2(&self) -> Result<f64, LapackError> {
667 let sum_sq: f64 = self.iter().map(|x| x.norm_sqr()).sum();
668 Ok(sum_sq.sqrt())
669 }
670 }
671
672 impl Norm for Array1<f64> {
674 type Real = f64;
675
676 #[inline]
677 fn norm_l2(&self) -> Result<f64, LapackError> {
678 let sum_sq: f64 = self.iter().map(|x| x * x).sum();
679 Ok(sum_sq.sqrt())
680 }
681 }
682
683 impl Norm for Array1<Complex<f64>> {
684 type Real = f64;
685
686 #[inline]
687 fn norm_l2(&self) -> Result<f64, LapackError> {
688 let sum_sq: f64 = self.iter().map(|x| x.norm_sqr()).sum();
689 Ok(sum_sq.sqrt())
690 }
691 }
692
693 pub trait QR {
695 type Elem;
696
697 fn qr(&self) -> Result<(Array2<Self::Elem>, Array2<Self::Elem>), LapackError>;
698 }
699
700 impl QR for Array2<f64> {
701 type Elem = f64;
702
703 #[inline]
704 fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), LapackError> {
705 let result = qr_ndarray(self)?;
706 Ok((result.q, result.r))
707 }
708 }
709
710 impl QR for Array2<Complex<f64>> {
711 type Elem = Complex<f64>;
712
713 #[inline]
714 fn qr(&self) -> Result<(Array2<Complex<f64>>, Array2<Complex<f64>>), LapackError> {
715 let result = qr_complex_ndarray(self)?;
716 Ok((result.q, result.r))
717 }
718 }
719
720 pub trait Eig {
722 type Elem;
723
724 fn eig(&self) -> Result<(Array1<Self::Elem>, Array2<Self::Elem>), LapackError>;
725 }
726
727 impl Eig for Array2<Complex<f64>> {
731 type Elem = Complex<f64>;
732
733 #[inline]
734 fn eig(&self) -> Result<(Array1<Complex<f64>>, Array2<Complex<f64>>), LapackError> {
735 let (eigenvalues_real, eigenvectors) = eig_hermitian_ndarray(self)?;
738 let eigenvalues = eigenvalues_real.mapv(|x| Complex::new(x, 0.0));
739 Ok((eigenvalues, eigenvectors))
740 }
741 }
742
743 pub trait Cholesky {
745 type Elem;
746
747 fn cholesky(&self, uplo: UPLO) -> Result<Array2<Self::Elem>, LapackError>;
748 }
749
750 impl Cholesky for Array2<f64> {
751 type Elem = f64;
752
753 #[inline]
754 fn cholesky(&self, _uplo: UPLO) -> Result<Array2<f64>, LapackError> {
755 let result = cholesky_ndarray(self)?;
756 Ok(result.l)
757 }
758 }
759
760 impl Cholesky for Array2<Complex<f64>> {
761 type Elem = Complex<f64>;
762
763 #[inline]
764 fn cholesky(&self, _uplo: UPLO) -> Result<Array2<Complex<f64>>, LapackError> {
765 let result = cholesky_hermitian_ndarray(self)?;
766 Ok(result.l)
767 }
768 }
769}