1pub use ndarray::{
10 s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, DataMut, Dim, Dimension, Ix1, Ix2,
11 Ix3, Ix4, Ix5, Ix6, IxDyn, OwnedRepr, RemoveAxis, ScalarOperand, ShapeBuilder, ShapeError,
12 SliceInfo, ViewRepr, Zip,
13};
14
15pub use ndarray::{ArcArray1, ArcArray2};
17pub use ndarray::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD};
18pub use ndarray::{
19 ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD,
20};
21pub use ndarray::{
22 ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5,
23 ArrayViewMut6, ArrayViewMutD,
24};
25
26pub mod indexing;
28
29pub mod stats;
31
32pub mod matrix;
34
35pub mod manipulation;
37
38#[allow(dead_code)]
62pub fn reshape_2d<T>(
63 array: ArrayView<T, Ix2>,
64 shape: (usize, usize),
65) -> Result<Array<T, Ix2>, &'static str>
66where
67 T: Clone + Default,
68{
69 let (rows, cols) = shape;
70 let total_elements = rows * cols;
71
72 if total_elements != array.len() {
74 return Err("New shape dimensions must match the total number of elements");
75 }
76
77 let mut result = Array::<T, Ix2>::default(shape);
79
80 let flat_iter = array.iter();
82 for (i, val) in flat_iter.enumerate() {
83 let r = i / cols;
84 let c = i % cols;
85 result[[r, c]] = val.clone();
86 }
87
88 Ok(result)
89}
90
91#[allow(dead_code)]
114pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
115where
116 T: Clone + Default,
117{
118 if arrays.is_empty() {
119 return Err("No _arrays provided for stacking");
120 }
121
122 let firstshape = arrays[0].shape();
124 for array in arrays.iter().skip(1) {
125 if array.shape() != firstshape {
126 return Err("All _arrays must have the same shape for stacking");
127 }
128 }
129
130 let (rows, cols) = (firstshape[0], firstshape[1]);
131
132 let (new_rows, new_cols) = match axis {
134 0 => (rows * arrays.len(), cols), 1 => (rows, cols * arrays.len()), _ => return Err("Axis must be 0 or 1 for 2D _arrays"),
137 };
138
139 let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
141
142 match axis {
144 0 => {
145 for (array_idx, array) in arrays.iter().enumerate() {
147 let start_row = array_idx * rows;
148 for r in 0..rows {
149 for c in 0..cols {
150 result[[start_row + r, c]] = array[[r, c]].clone();
151 }
152 }
153 }
154 }
155 1 => {
156 for (array_idx, array) in arrays.iter().enumerate() {
158 let start_col = array_idx * cols;
159 for r in 0..rows {
160 for c in 0..cols {
161 result[[r, start_col + c]] = array[[r, c]].clone();
162 }
163 }
164 }
165 }
166 _ => unreachable!(),
167 }
168
169 Ok(result)
170}
171
172#[allow(dead_code)]
195pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
196where
197 T: Clone,
198{
199 array.t().to_owned()
200}
201
202#[allow(dead_code)]
227pub fn split_2d<T>(
228 array: ArrayView<T, Ix2>,
229 indices: &[usize],
230 axis: usize,
231) -> Result<Vec<Array<T, Ix2>>, &'static str>
232where
233 T: Clone + Default,
234{
235 if indices.is_empty() {
236 return Ok(vec![array.to_owned()]);
237 }
238
239 let (rows, cols) = (array.shape()[0], array.shape()[1]);
240 let axis_len = if axis == 0 { rows } else { cols };
241
242 for &idx in indices {
244 if idx >= axis_len {
245 return Err("Split index out of bounds");
246 }
247 }
248
249 let mut sorted_indices = indices.to_vec();
251 sorted_indices.sort_unstable();
252
253 let mut starts = vec![0];
255 starts.extend_from_slice(&sorted_indices);
256
257 let mut ends = sorted_indices.clone();
258 ends.push(axis_len);
259
260 let mut result = Vec::with_capacity(starts.len());
262
263 match axis {
264 0 => {
265 for (&start, &end) in starts.iter().zip(ends.iter()) {
267 let sub_rows = end - start;
268 let mut sub_array = Array::<T, Ix2>::default((sub_rows, cols));
269
270 for r in 0..sub_rows {
271 for c in 0..cols {
272 sub_array[[r, c]] = array[[start + r, c]].clone();
273 }
274 }
275
276 result.push(sub_array);
277 }
278 }
279 1 => {
280 for (&start, &end) in starts.iter().zip(ends.iter()) {
282 let sub_cols = end - start;
283 let mut sub_array = Array::<T, Ix2>::default((rows, sub_cols));
284
285 for r in 0..rows {
286 for c in 0..sub_cols {
287 sub_array[[r, c]] = array[[r, start + c]].clone();
288 }
289 }
290
291 result.push(sub_array);
292 }
293 }
294 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
295 }
296
297 Ok(result)
298}
299
300#[allow(dead_code)]
326pub fn take_2d<T>(
327 array: ArrayView<T, Ix2>,
328 indices: ArrayView<usize, Ix1>,
329 axis: usize,
330) -> Result<Array<T, Ix2>, &'static str>
331where
332 T: Clone + Default,
333{
334 let (rows, cols) = (array.shape()[0], array.shape()[1]);
335 let axis_len = if axis == 0 { rows } else { cols };
336
337 for &idx in indices.iter() {
339 if idx >= axis_len {
340 return Err("Index out of bounds");
341 }
342 }
343
344 let (result_rows, result_cols) = match axis {
346 0 => (indices.len(), cols),
347 1 => (rows, indices.len()),
348 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
349 };
350
351 let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
352
353 match axis {
355 0 => {
356 for (i, &idx) in indices.iter().enumerate() {
358 for j in 0..cols {
359 result[[i, j]] = array[[idx, j]].clone();
360 }
361 }
362 }
363 1 => {
364 for i in 0..rows {
366 for (j, &idx) in indices.iter().enumerate() {
367 result[[i, j]] = array[[i, idx]].clone();
368 }
369 }
370 }
371 _ => unreachable!(),
372 }
373
374 Ok(result)
375}
376
377#[allow(dead_code)]
403pub fn mask_select<T>(
404 array: ArrayView<T, Ix2>,
405 mask: ArrayView<bool, Ix2>,
406) -> Result<Array<T, Ix1>, &'static str>
407where
408 T: Clone + Default,
409{
410 if array.shape() != mask.shape() {
412 return Err("Mask shape must match array shape");
413 }
414
415 let true_count = mask.iter().filter(|&&x| x).count();
417
418 let mut result = Array::<T, Ix1>::default(true_count);
420
421 let mut idx = 0;
423 for (val, &m) in array.iter().zip(mask.iter()) {
424 if m {
425 result[idx] = val.clone();
426 idx += 1;
427 }
428 }
429
430 Ok(result)
431}
432
433#[allow(dead_code)]
460pub fn fancy_index_2d<T>(
461 array: ArrayView<T, Ix2>,
462 row_indices: ArrayView<usize, Ix1>,
463 col_indices: ArrayView<usize, Ix1>,
464) -> Result<Array<T, Ix1>, &'static str>
465where
466 T: Clone + Default,
467{
468 let result_size = row_indices.len();
470 if col_indices.len() != result_size {
471 return Err("Row and column index arrays must have the same length");
472 }
473
474 let (rows, cols) = (array.shape()[0], array.shape()[1]);
475
476 for &idx in row_indices.iter() {
478 if idx >= rows {
479 return Err("Row index out of bounds");
480 }
481 }
482
483 for &idx in col_indices.iter() {
484 if idx >= cols {
485 return Err("Column index out of bounds");
486 }
487 }
488
489 let mut result = Array::<T, Ix1>::default(result_size);
491
492 for i in 0..result_size {
494 let row = row_indices[i];
495 let col = col_indices[i];
496 result[i] = array[[row, col]].clone();
497 }
498
499 Ok(result)
500}
501
502#[allow(dead_code)]
527pub fn where_condition<T, F>(
528 array: ArrayView<T, Ix2>,
529 condition: F,
530) -> Result<Array<T, Ix1>, &'static str>
531where
532 T: Clone + Default,
533 F: Fn(&T) -> bool,
534{
535 let mask = array.map(condition);
537
538 mask_select(array, mask.view())
540}
541
/// Reports whether two shapes can be broadcast together (NumPy-style rules).
///
/// Shapes are aligned at their trailing dimensions. Each aligned pair must
/// either be equal or contain a 1. Dimensions missing from the shorter shape
/// behave as 1 and therefore never conflict.
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    // Walking both shapes from the back aligns trailing axes; `zip` stops at the
    // shorter one, which is fine because the implicit leading 1s always match.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
592
/// Computes the shape that results from broadcasting `shape1` with `shape2`.
///
/// Shapes are aligned at their trailing dimensions, and missing leading
/// dimensions are treated as 1. For each aligned pair:
/// - equal sizes are kept;
/// - a size of 1 stretches to match the other size, including 0, so `1`
///   with `0` gives `0`.
///
/// Returns `None` if any aligned pair is incompatible (unequal and neither is 1).
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let ndim = shape1.len().max(shape2.len());

    // Size of `shape` at output axis `i`, treating missing leading axes as 1.
    let dim_at = |shape: &[usize], i: usize| -> usize {
        let offset = ndim - shape.len();
        if i < offset {
            1
        } else {
            shape[i - offset]
        }
    };

    // Taking `max` would be wrong for zero-length axes (max(0, 1) == 1), so pick
    // the non-1 side explicitly. `collect` into `Option` short-circuits on the
    // first incompatible pair.
    (0..ndim)
        .map(|i| {
            let (d1, d2) = (dim_at(shape1, i), dim_at(shape2, i));
            if d1 == d2 || d2 == 1 {
                Some(d1)
            } else if d1 == 1 {
                Some(d2)
            } else {
                None
            }
        })
        .collect()
}
634
635#[allow(dead_code)]
660pub fn broadcast_1d_to_2d<T>(
661 array: ArrayView<T, Ix1>,
662 repeats: usize,
663 axis: usize,
664) -> Result<Array<T, Ix2>, &'static str>
665where
666 T: Clone + Default,
667{
668 let len = array.len();
669
670 let (rows, cols) = match axis {
672 0 => (repeats, len), 1 => (len, repeats), _ => return Err("Axis must be 0 or 1"),
675 };
676
677 let mut result = Array::<T, Ix2>::default((rows, cols));
678
679 match axis {
681 0 => {
682 for i in 0..repeats {
684 for j in 0..len {
685 result[[i, j]] = array[j].clone();
686 }
687 }
688 }
689 1 => {
690 for i in 0..len {
692 for j in 0..repeats {
693 result[[i, j]] = array[i].clone();
694 }
695 }
696 }
697 _ => unreachable!(),
698 }
699
700 Ok(result)
701}
702
703#[allow(dead_code)]
729pub fn broadcast_apply<T, R, F>(
730 a: ArrayView<T, Ix2>,
731 b: ArrayView<T, Ix1>,
732 op: F,
733) -> Result<Array<R, Ix2>, &'static str>
734where
735 T: Clone + Default,
736 R: Clone + Default,
737 F: Fn(&T, &T) -> R,
738{
739 let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
740 let b_len = b.len();
741
742 if a_cols != b_len {
744 return Err("Arrays are not broadcast compatible");
745 }
746
747 let mut result = Array::<R, Ix2>::default((a_rows, a_cols));
749
750 for i in 0..a_rows {
752 for j in 0..a_cols {
753 result[[i, j]] = op(&a[[i, j]], &b[j]);
754 }
755 }
756
757 Ok(result)
758}
759
// Unit tests for the 2D reshaping, stacking, indexing and broadcasting helpers.
// The tests cover both the happy paths and the documented error paths.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_reshape_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = reshape_2d(a.view(), (4, 1)).unwrap();
        assert_eq!(b.shape(), &[4, 1]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 0]], 3);
        assert_eq!(b[[3, 0]], 4);

        let result = reshape_2d(a.view(), (3, 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
        assert_eq!(c.shape(), &[4, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[1, 0]], 3);
        assert_eq!(c[[2, 0]], 5);
        assert_eq!(c[[3, 0]], 7);

        let d = stack_2d(&[a.view(), b.view()], 1).unwrap();
        assert_eq!(d.shape(), &[2, 4]);
        assert_eq!(d[[0, 0]], 1);
        assert_eq!(d[[0, 1]], 2);
        assert_eq!(d[[0, 2]], 5);
        assert_eq!(d[[0, 3]], 6);
    }

    #[test]
    fn test_stack_2d_errors() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[1, 2, 3]];
        // No inputs, mismatched shapes, and an out-of-range axis are all rejected.
        assert!(stack_2d::<i32>(&[], 0).is_err());
        assert!(stack_2d(&[a.view(), b.view()], 0).is_err());
        assert!(stack_2d(&[a.view(), a.view()], 2).is_err());
    }

    #[test]
    fn test_transpose_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = transpose_2d(a.view());
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 4);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 1]], 6);
    }

    #[test]
    fn test_split_2d() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        let result = split_2d(a.view(), &[2], 1).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 2]);
        assert_eq!(result[0][[0, 0]], 1);
        assert_eq!(result[0][[0, 1]], 2);
        assert_eq!(result[0][[1, 0]], 5);
        assert_eq!(result[0][[1, 1]], 6);
        assert_eq!(result[1].shape(), &[2, 2]);
        assert_eq!(result[1][[0, 0]], 3);
        assert_eq!(result[1][[0, 1]], 4);
        assert_eq!(result[1][[1, 0]], 7);
        assert_eq!(result[1][[1, 1]], 8);

        let result = split_2d(a.view(), &[1], 0).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[1, 4]);
        assert_eq!(result[1].shape(), &[1, 4]);
    }

    #[test]
    fn test_split_2d_edge_cases() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        // No split points: a single copy of the whole array.
        let whole = split_2d(a.view(), &[], 0).unwrap();
        assert_eq!(whole.len(), 1);
        assert_eq!(whole[0], a);

        // An index equal to the axis length is rejected.
        assert!(split_2d(a.view(), &[3], 1).is_err());

        // Unsorted split points are sorted before use.
        let parts = split_2d(a.view(), &[2, 1], 1).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0], array![[1], [4]]);
        assert_eq!(parts[1], array![[2], [5]]);
        assert_eq!(parts[2], array![[3], [6]]);
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let indices = array![0, 2];

        let result = take_2d(a.view(), indices.view(), 1).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 3);
        assert_eq!(result[[1, 0]], 4);
        assert_eq!(result[[1, 1]], 6);
    }

    #[test]
    fn test_take_2d_rows_and_errors() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        // Rows may be repeated and reordered.
        let rows = take_2d(a.view(), array![1, 0, 1].view(), 0).unwrap();
        assert_eq!(rows, array![[4, 5, 6], [1, 2, 3], [4, 5, 6]]);

        assert!(take_2d(a.view(), array![2].view(), 0).is_err());
        assert!(take_2d(a.view(), array![0].view(), 2).is_err());
    }

    #[test]
    fn test_mask_select() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let result = mask_select(a.view(), mask.view()).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 3);
        assert_eq!(result[2], 5);
    }

    #[test]
    fn test_mask_select_shape_mismatch() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false], [false, true]];
        assert!(mask_select(a.view(), mask.view()).is_err());
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(result.shape(), &[2]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 8);
    }

    #[test]
    fn test_fancy_index_2d_errors() {
        let a = array![[1, 2], [3, 4]];
        // Length mismatch, bad row index, bad column index.
        assert!(fancy_index_2d(a.view(), array![0, 1].view(), array![0].view()).is_err());
        assert!(fancy_index_2d(a.view(), array![2].view(), array![0].view()).is_err());
        assert!(fancy_index_2d(a.view(), array![0].view(), array![2].view()).is_err());
    }

    #[test]
    fn test_where_condition() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let result = where_condition(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 4);
        assert_eq!(result[1], 5);
        assert_eq!(result[2], 6);
    }

    #[test]
    fn test_broadcast_shapes() {
        assert!(is_broadcast_compatible(&[2, 3], &[3]));
        assert!(is_broadcast_compatible(&[4, 1, 3], &[5, 1]));
        assert!(!is_broadcast_compatible(&[2, 3], &[2]));

        assert_eq!(broadcastshape(&[2, 1], &[3]), Some(vec![2, 3]));
        assert_eq!(broadcastshape(&[4, 1, 3], &[5, 1]), Some(vec![4, 5, 3]));
        assert_eq!(broadcastshape(&[2, 3], &[2]), None);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let a = array![1, 2, 3];

        let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 2);
        assert_eq!(b[[1, 0]], 1);
        assert_eq!(b[[1, 2]], 3);

        let c = broadcast_1d_to_2d(a.view(), 2, 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[0, 1]], 1);
        assert_eq!(c[[1, 0]], 2);
        assert_eq!(c[[2, 1]], 3);

        assert!(broadcast_1d_to_2d(a.view(), 2, 2).is_err());
    }

    #[test]
    fn test_broadcast_apply() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = array![10, 20, 30];

        let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 11);
        assert_eq!(result[[0, 1]], 22);
        assert_eq!(result[[0, 2]], 33);
        assert_eq!(result[[1, 0]], 14);
        assert_eq!(result[[1, 1]], 25);
        assert_eq!(result[[1, 2]], 36);

        let result = broadcast_apply(a.view(), b.view(), |x, y| x * y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 10);
        assert_eq!(result[[0, 1]], 40);
        assert_eq!(result[[0, 2]], 90);
        assert_eq!(result[[1, 0]], 40);
        assert_eq!(result[[1, 1]], 100);
        assert_eq!(result[[1, 2]], 180);
    }

    #[test]
    fn test_broadcast_apply_mismatch() {
        let a = array![[1, 2, 3]];
        let b = array![1, 2];
        assert!(broadcast_apply(a.view(), b.view(), |x, y| x + y).is_err());
    }
}