scirs2_core/ndarray_ext/mod.rs
1//! Extended ndarray operations for scientific computing
2//!
3//! This module provides additional functionality for ndarray to support
4//! the advanced array operations necessary for a complete SciPy port.
5//! It implements core `NumPy`-like features that are not available in the
6//! base ndarray crate.
7
8/// Re-export essential ndarray types and macros for convenience
9/// Use ::ndarray to refer to the external crate directly
10pub use ::ndarray::{
11 s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, DataMut, Dim, Dimension, Ix1, Ix2,
12 Ix3, Ix4, Ix5, Ix6, IxDyn, OwnedRepr, RemoveAxis, ScalarOperand, ShapeBuilder, ShapeError,
13 SliceInfo, ViewRepr, Zip,
14};
15
16// Re-export the array! macro for convenient array creation
17// This addresses the common issue where users expect array! to be available in scirs2-core
18// instead of requiring import from scirs2_autograd::ndarray
19pub use ::ndarray::{arr1, arr2, array};
20
21// Re-export type aliases for common array sizes
22pub use ::ndarray::{ArcArray1, ArcArray2};
23pub use ::ndarray::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD};
24pub use ::ndarray::{
25 ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD,
26};
27pub use ::ndarray::{
28 ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5,
29 ArrayViewMut6, ArrayViewMutD,
30};
31
32/// Advanced indexing operations (`NumPy`-like boolean masking, fancy indexing, etc.)
33pub mod indexing;
34
35/// Statistical functions for ndarray arrays (mean, median, variance, correlation, etc.)
36pub mod stats;
37
38/// Matrix operations (eye, diag, kron, etc.)
39pub mod matrix;
40
41/// Array manipulation operations (flip, roll, tile, repeat, etc.)
42pub mod manipulation;
43
44/// Array reduction operations with SIMD acceleration (argmin, argmax, etc.)
45pub mod reduction;
46
47/// Data preprocessing operations with SIMD acceleration (normalization, standardization)
48pub mod preprocessing;
49
50/// Element-wise mathematical operations with SIMD acceleration (abs, sqrt, etc.)
51pub mod elementwise;
52
53/// Random array generation with full ndarray-rand compatibility and scientific computing extensions
54#[cfg(feature = "random")]
55pub mod random;
56
57/// Reshape a 2D array to a new shape without copying data when possible
58///
59/// # Arguments
60///
61/// * `array` - The input array
62/// * `shape` - The new shape (rows, cols), which must be compatible with the original shape
63///
64/// # Returns
65///
66/// A reshaped array view of the input array if possible, or a new array otherwise
67///
68/// # Examples
69///
70/// ```
71/// use scirs2_core::ndarray::array;
72/// use scirs2_core::ndarray_ext::reshape_2d;
73///
74/// let a = array![[1, 2], [3, 4]];
75/// let b = reshape_2d(a.view(), (4, 1)).unwrap();
76/// assert_eq!(b.shape(), &[4, 1]);
77/// assert_eq!(b[[0, 0]], 1);
78/// assert_eq!(b[[3, 0]], 4);
79/// ```
80#[allow(dead_code)]
81pub fn reshape_2d<T>(
82 array: ArrayView<T, Ix2>,
83 shape: (usize, usize),
84) -> Result<Array<T, Ix2>, &'static str>
85where
86 T: Clone + Default,
87{
88 let (rows, cols) = shape;
89 let total_elements = rows * cols;
90
91 // Check if the new shape is compatible with the original shape
92 if total_elements != array.len() {
93 return Err("New shape dimensions must match the total number of elements");
94 }
95
96 // Create a new array with the specified shape
97 let mut result = Array::<T, Ix2>::default(shape);
98
99 // Fill the result array with elements from the input array
100 let flat_iter = array.iter();
101 for (i, val) in flat_iter.enumerate() {
102 let r = i / cols;
103 let c = i % cols;
104 result[[r, c]] = val.clone();
105 }
106
107 Ok(result)
108}
109
110/// Stack 2D arrays along a given axis
111///
112/// # Arguments
113///
114/// * `arrays` - A slice of 2D arrays to stack
115/// * `axis` - The axis along which to stack (0 for rows, 1 for columns)
116///
117/// # Returns
118///
119/// A new array containing the stacked arrays
120///
121/// # Examples
122///
123/// ```
124/// use ndarray::{array, Axis};
125/// use scirs2_core::ndarray_ext::stack_2d;
126///
127/// let a = array![[1, 2], [3, 4]];
128/// let b = array![[5, 6], [7, 8]];
129/// let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
130/// assert_eq!(c.shape(), &[4, 2]);
131/// ```
132#[allow(dead_code)]
133pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
134where
135 T: Clone + Default,
136{
137 if arrays.is_empty() {
138 return Err("No _arrays provided for stacking");
139 }
140
141 // Validate that all _arrays have the same shape
142 let firstshape = arrays[0].shape();
143 for array in arrays.iter().skip(1) {
144 if array.shape() != firstshape {
145 return Err("All _arrays must have the same shape for stacking");
146 }
147 }
148
149 let (rows, cols) = (firstshape[0], firstshape[1]);
150
151 // Calculate the new shape
152 let (new_rows, new_cols) = match axis {
153 0 => (rows * arrays.len(), cols), // Stack vertically
154 1 => (rows, cols * arrays.len()), // Stack horizontally
155 _ => return Err("Axis must be 0 or 1 for 2D _arrays"),
156 };
157
158 // Create a new array to hold the stacked result
159 let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
160
161 // Copy data from the input _arrays to the result
162 match axis {
163 0 => {
164 // Stack vertically (along rows)
165 for (array_idx, array) in arrays.iter().enumerate() {
166 let start_row = array_idx * rows;
167 for r in 0..rows {
168 for c in 0..cols {
169 result[[start_row + r, c]] = array[[r, c]].clone();
170 }
171 }
172 }
173 }
174 1 => {
175 // Stack horizontally (along columns)
176 for (array_idx, array) in arrays.iter().enumerate() {
177 let start_col = array_idx * cols;
178 for r in 0..rows {
179 for c in 0..cols {
180 result[[r, start_col + c]] = array[[r, c]].clone();
181 }
182 }
183 }
184 }
185 _ => unreachable!(),
186 }
187
188 Ok(result)
189}
190
191/// Swap axes (transpose) of a 2D array
192///
193/// # Arguments
194///
195/// * `array` - The input 2D array
196///
197/// # Returns
198///
199/// A view of the input array with the axes swapped
200///
201/// # Examples
202///
203/// ```ignore
204/// use crate::ndarray::array;
205/// use scirs2_core::ndarray_ext::transpose_2d;
206///
207/// let a = array![[1, 2, 3], [4, 5, 6]];
208/// let b = transpose_2d(a.view());
209/// assert_eq!(b.shape(), &[3, 2]);
210/// assert_eq!(b[[0, 0]], 1);
211/// assert_eq!(b[[0, 1]], 4);
212/// ```
213#[allow(dead_code)]
214pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
215where
216 T: Clone,
217{
218 array.t().to_owned()
219}
220
221/// Split a 2D array into multiple sub-arrays along a given axis
222///
223/// # Arguments
224///
225/// * `array` - The input 2D array to split
226/// * `indices` - Indices where the array should be split
227/// * `axis` - The axis along which to split (0 for rows, 1 for columns)
228///
229/// # Returns
230///
231/// A vector of arrays resulting from the split
232///
233/// # Examples
234///
235/// ```
236/// use scirs2_core::ndarray::array;
237/// use scirs2_core::ndarray_ext::split_2d;
238///
239/// let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
240/// let result = split_2d(a.view(), &[2], 1).unwrap();
241/// assert_eq!(result.len(), 2);
242/// assert_eq!(result[0].shape(), &[2, 2]);
243/// assert_eq!(result[1].shape(), &[2, 2]);
244/// ```
245#[allow(dead_code)]
246pub fn split_2d<T>(
247 array: ArrayView<T, Ix2>,
248 indices: &[usize],
249 axis: usize,
250) -> Result<Vec<Array<T, Ix2>>, &'static str>
251where
252 T: Clone + Default,
253{
254 if indices.is_empty() {
255 return Ok(vec![array.to_owned()]);
256 }
257
258 let (rows, cols) = (array.shape()[0], array.shape()[1]);
259 let axis_len = if axis == 0 { rows } else { cols };
260
261 // Validate indices
262 for &idx in indices {
263 if idx >= axis_len {
264 return Err("Split index out of bounds");
265 }
266 }
267
268 // Sort indices to ensure they're in ascending order
269 let mut sorted_indices = indices.to_vec();
270 sorted_indices.sort_unstable();
271
272 // Calculate the sub-array boundaries
273 let mut starts = vec![0];
274 starts.extend_from_slice(&sorted_indices);
275
276 let mut ends = sorted_indices.clone();
277 ends.push(axis_len);
278
279 // Create the split sub-arrays
280 let mut result = Vec::with_capacity(starts.len());
281
282 match axis {
283 0 => {
284 // Split along rows
285 for (&start, &end) in starts.iter().zip(ends.iter()) {
286 let sub_rows = end - start;
287 let mut sub_array = Array::<T, Ix2>::default((sub_rows, cols));
288
289 for r in 0..sub_rows {
290 for c in 0..cols {
291 sub_array[[r, c]] = array[[start + r, c]].clone();
292 }
293 }
294
295 result.push(sub_array);
296 }
297 }
298 1 => {
299 // Split along columns
300 for (&start, &end) in starts.iter().zip(ends.iter()) {
301 let sub_cols = end - start;
302 let mut sub_array = Array::<T, Ix2>::default((rows, sub_cols));
303
304 for r in 0..rows {
305 for c in 0..sub_cols {
306 sub_array[[r, c]] = array[[r, start + c]].clone();
307 }
308 }
309
310 result.push(sub_array);
311 }
312 }
313 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
314 }
315
316 Ok(result)
317}
318
319/// Take elements from a 2D array along a given axis using indices from another array
320///
321/// # Arguments
322///
323/// * `array` - The input 2D array
324/// * `indices` - Array of indices to take
325/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
326///
327/// # Returns
328///
329/// An array of values at the specified indices along the given axis
330///
331/// # Examples
332///
333/// ```
334/// use scirs2_core::ndarray::array;
335/// use scirs2_core::ndarray_ext::take_2d;
336///
337/// let a = array![[1, 2, 3], [4, 5, 6]];
338/// let indices = array![0_usize, 2];
339/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
340/// assert_eq!(result.shape(), &[2, 2]);
341/// assert_eq!(result[[0, 0]], 1);
342/// assert_eq!(result[[0, 1]], 3);
343/// ```
344#[allow(dead_code)]
345pub fn take_2d<T>(
346 array: ArrayView<T, Ix2>,
347 indices: ArrayView<usize, Ix1>,
348 axis: usize,
349) -> Result<Array<T, Ix2>, &'static str>
350where
351 T: Clone + Default,
352{
353 let (rows, cols) = (array.shape()[0], array.shape()[1]);
354 let axis_len = if axis == 0 { rows } else { cols };
355
356 // Check that indices are within bounds
357 for &idx in indices.iter() {
358 if idx >= axis_len {
359 return Err("Index out of bounds");
360 }
361 }
362
363 // Create the result array with the appropriate shape
364 let (result_rows, result_cols) = match axis {
365 0 => (indices.len(), cols),
366 1 => (rows, indices.len()),
367 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
368 };
369
370 let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
371
372 // Fill the result array
373 match axis {
374 0 => {
375 // Take along rows
376 for (i, &idx) in indices.iter().enumerate() {
377 for j in 0..cols {
378 result[[i, j]] = array[[idx, j]].clone();
379 }
380 }
381 }
382 1 => {
383 // Take along columns
384 for i in 0..rows {
385 for (j, &idx) in indices.iter().enumerate() {
386 result[[i, j]] = array[[i, idx]].clone();
387 }
388 }
389 }
390 _ => unreachable!(),
391 }
392
393 Ok(result)
394}
395
396/// Filter an array using a boolean mask
397///
398/// # Arguments
399///
400/// * `array` - The input array
401/// * `mask` - Boolean mask of the same shape as the array
402///
403/// # Returns
404///
405/// A 1D array containing the elements where the mask is true
406///
407/// # Examples
408///
409/// ```ignore
410/// use crate::ndarray::array;
411/// use scirs2_core::ndarray_ext::mask_select;
412///
413/// let a = array![[1, 2, 3], [4, 5, 6]];
414/// let mask = array![[true, false, true], [false, true, false]];
415/// let result = mask_select(a.view(), mask.view()).unwrap();
416/// assert_eq!(result.shape(), &[3]);
417/// assert_eq!(result[0], 1);
418/// assert_eq!(result[1], 3);
419/// assert_eq!(result[2], 5);
420/// ```
421#[allow(dead_code)]
422pub fn mask_select<T>(
423 array: ArrayView<T, Ix2>,
424 mask: ArrayView<bool, Ix2>,
425) -> Result<Array<T, Ix1>, &'static str>
426where
427 T: Clone + Default,
428{
429 // Check that the mask has the same shape as the array
430 if array.shape() != mask.shape() {
431 return Err("Mask shape must match array shape");
432 }
433
434 // Count the number of true values in the mask
435 let true_count = mask.iter().filter(|&&x| x).count();
436
437 // Create the result array
438 let mut result = Array::<T, Ix1>::default(true_count);
439
440 // Fill the result array with elements where the mask is true
441 let mut idx = 0;
442 for (val, &m) in array.iter().zip(mask.iter()) {
443 if m {
444 result[idx] = val.clone();
445 idx += 1;
446 }
447 }
448
449 Ok(result)
450}
451
452/// Index a 2D array with a list of index arrays (fancy indexing)
453///
454/// # Arguments
455///
456/// * `array` - The input 2D array
457/// * `row_indices` - Array of row indices
458/// * `col_indices` - Array of column indices
459///
460/// # Returns
461///
462/// A 1D array containing the elements at the specified indices
463///
464/// # Examples
465///
466/// ```ignore
467/// use crate::ndarray::array;
468/// use scirs2_core::ndarray_ext::fancy_index_2d;
469///
470/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
471/// let row_indices = array![0, 2];
472/// let col_indices = array![0, 1];
473/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
474/// assert_eq!(result.shape(), &[2]);
475/// assert_eq!(result[0], 1);
476/// assert_eq!(result[1], 8);
477/// ```
478#[allow(dead_code)]
479pub fn fancy_index_2d<T>(
480 array: ArrayView<T, Ix2>,
481 row_indices: ArrayView<usize, Ix1>,
482 col_indices: ArrayView<usize, Ix1>,
483) -> Result<Array<T, Ix1>, &'static str>
484where
485 T: Clone + Default,
486{
487 // Check that all index arrays have the same length
488 let result_size = row_indices.len();
489 if col_indices.len() != result_size {
490 return Err("Row and column index arrays must have the same length");
491 }
492
493 let (rows, cols) = (array.shape()[0], array.shape()[1]);
494
495 // Check that _indices are within bounds
496 for &idx in row_indices.iter() {
497 if idx >= rows {
498 return Err("Row index out of bounds");
499 }
500 }
501
502 for &idx in col_indices.iter() {
503 if idx >= cols {
504 return Err("Column index out of bounds");
505 }
506 }
507
508 // Create the result array
509 let mut result = Array::<T, Ix1>::default(result_size);
510
511 // Fill the result array
512 for i in 0..result_size {
513 let row = row_indices[i];
514 let col = col_indices[i];
515 result[i] = array[[row, col]].clone();
516 }
517
518 Ok(result)
519}
520
521/// Select elements from an array where a condition is true
522///
523/// # Arguments
524///
525/// * `array` - The input array
526/// * `condition` - A function that takes a reference to an element and returns a bool
527///
528/// # Returns
529///
530/// A 1D array containing the elements where the condition is true
531///
532/// # Examples
533///
534/// ```ignore
535/// use crate::ndarray::array;
536/// use scirs2_core::ndarray_ext::where_condition;
537///
538/// let a = array![[1, 2, 3], [4, 5, 6]];
539/// let result = where_condition(a.view(), |&x| x > 3).unwrap();
540/// assert_eq!(result.shape(), &[3]);
541/// assert_eq!(result[0], 4);
542/// assert_eq!(result[1], 5);
543/// assert_eq!(result[2], 6);
544/// ```
545#[allow(dead_code)]
546pub fn where_condition<T, F>(
547 array: ArrayView<T, Ix2>,
548 condition: F,
549) -> Result<Array<T, Ix1>, &'static str>
550where
551 T: Clone + Default,
552 F: Fn(&T) -> bool,
553{
554 // Build a boolean mask array based on the condition
555 let mask = array.map(condition);
556
557 // Use the mask_select function to select elements
558 mask_select(array, mask.view())
559}
560
/// Check if two shapes are broadcast compatible
///
/// Shapes are compared from their trailing dimension backwards. A shape with
/// fewer dimensions is treated as if it were padded on the left with 1s, so
/// only the overlapping trailing dimensions need to be inspected.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// `true` if the shapes are broadcast compatible, `false` otherwise
///
/// # Examples
///
/// ```ignore
/// use scirs2_core::ndarray_ext::is_broadcast_compatible;
///
/// assert!(is_broadcast_compatible(&[2, 3], &[3]));
/// // This example has dimensions that don't match (5 vs 3 in dimension 0)
/// // and aren't broadcasting compatible (neither is 1)
/// assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
/// assert!(!is_broadcast_compatible(&[2, 3], &[4]));
/// ```
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    // `zip` stops at the shorter shape; the dimensions it skips pair with an
    // implicit 1 and are therefore always compatible.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
611
/// Calculate the broadcasted shape from two input shapes
///
/// Shapes are aligned on their trailing dimensions; the shorter one is
/// treated as if it were padded on the left with 1s. For each aligned pair
/// the dimensions must be equal or one of them must be 1, and the size-1
/// side is stretched to match the other. This also holds for zero-length
/// axes: broadcasting 1 against 0 gives 0.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// The broadcasted shape as a `Vec<usize>`, or `None` if the shapes are incompatible
///
/// # Examples
///
/// ```ignore
/// use scirs2_core::ndarray_ext::broadcastshape;
///
/// assert_eq!(broadcastshape(&[2, 3], &[3]), Some(vec![2, 3]));
/// assert_eq!(broadcastshape(&[5, 1, 4], &[1, 3, 1]), Some(vec![5, 3, 4]));
/// assert_eq!(broadcastshape(&[2, 3], &[4]), None);
/// ```
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    // Order the shapes by rank so the leading dimensions of the longer one
    // (which pair with implicit 1s) can be copied over unchanged.
    let (long, short) = if shape1.len() >= shape2.len() {
        (shape1, shape2)
    } else {
        (shape2, shape1)
    };
    let offset = long.len() - short.len();

    let mut result = Vec::with_capacity(long.len());
    result.extend_from_slice(&long[..offset]);

    for (&a, &b) in long[offset..].iter().zip(short) {
        let dim = match (a, b) {
            _ if a == b => a,
            // The size-1 side stretches to the other side. Taking the other
            // side (rather than the maximum) keeps zero-length axes at 0.
            (1, other) | (other, 1) => other,
            _ => return None,
        };
        result.push(dim);
    }

    Some(result)
}
653
654/// Broadcast a 1D array to a 2D shape by repeating it along the specified axis
655///
656/// # Arguments
657///
658/// * `array` - The input 1D array
659/// * `repeats` - Number of times to repeat the array
660/// * `axis` - Axis along which to repeat (0 for rows, 1 for columns)
661///
662/// # Returns
663///
664/// A 2D array with the input array repeated along the specified axis
665///
666/// # Examples
667///
668/// ```ignore
669/// use crate::ndarray::array;
670/// use scirs2_core::ndarray_ext::broadcast_1d_to_2d;
671///
672/// let a = array![1, 2, 3];
673/// let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
674/// assert_eq!(b.shape(), &[2, 3]);
675/// assert_eq!(b[[0, 0]], 1);
676/// assert_eq!(b[[1, 0]], 1);
677/// ```
678#[allow(dead_code)]
679pub fn broadcast_1d_to_2d<T>(
680 array: ArrayView<T, Ix1>,
681 repeats: usize,
682 axis: usize,
683) -> Result<Array<T, Ix2>, &'static str>
684where
685 T: Clone + Default,
686{
687 let len = array.len();
688
689 // Create the result array with the appropriate shape
690 let (rows, cols) = match axis {
691 0 => (repeats, len), // Broadcast along rows
692 1 => (len, repeats), // Broadcast along columns
693 _ => return Err("Axis must be 0 or 1"),
694 };
695
696 let mut result = Array::<T, Ix2>::default((rows, cols));
697
698 // Fill the result array
699 match axis {
700 0 => {
701 // Broadcast along rows
702 for i in 0..repeats {
703 for j in 0..len {
704 result[[i, j]] = array[j].clone();
705 }
706 }
707 }
708 1 => {
709 // Broadcast along columns
710 for i in 0..len {
711 for j in 0..repeats {
712 result[[i, j]] = array[i].clone();
713 }
714 }
715 }
716 _ => unreachable!(),
717 }
718
719 Ok(result)
720}
721
722/// Apply an element-wise binary operation to two arrays with broadcasting
723///
724/// # Arguments
725///
726/// * `a` - First array (2D)
727/// * `b` - Second array (can be 1D or 2D)
728/// * `op` - Binary operation to apply to each pair of elements
729///
730/// # Returns
731///
732/// A 2D array containing the result of the operation applied element-wise
733///
734/// # Examples
735///
736/// ```ignore
737/// use crate::ndarray::array;
738/// use scirs2_core::ndarray_ext::broadcast_apply;
739///
740/// let a = array![[1, 2, 3], [4, 5, 6]];
741/// let b = array![10, 20, 30];
742/// let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
743/// assert_eq!(result.shape(), &[2, 3]);
744/// assert_eq!(result[[0, 0]], 11);
745/// assert_eq!(result[[1, 2]], 36);
746/// ```
747#[allow(dead_code)]
748pub fn broadcast_apply<T, R, F>(
749 a: ArrayView<T, Ix2>,
750 b: ArrayView<T, Ix1>,
751 op: F,
752) -> Result<Array<R, Ix2>, &'static str>
753where
754 T: Clone + Default,
755 R: Clone + Default,
756 F: Fn(&T, &T) -> R,
757{
758 let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
759 let b_len = b.len();
760
761 // Check that the arrays are broadcast compatible
762 if a_cols != b_len {
763 return Err("Arrays are not broadcast compatible");
764 }
765
766 // Create the result array
767 let mut result = Array::<R, Ix2>::default((a_rows, a_cols));
768
769 // Apply the operation element-wise
770 for i in 0..a_rows {
771 for j in 0..a_cols {
772 result[[i, j]] = op(&a[[i, j]], &b[j]);
773 }
774 }
775
776 Ok(result)
777}
778
// Unit tests for the 2D helpers in this module. Besides the happy paths they
// cover the error paths and the two shape-only broadcasting helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::array;

    #[test]
    fn test_reshape_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = reshape_2d(a.view(), (4, 1)).unwrap();
        assert_eq!(b.shape(), &[4, 1]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 0]], 3);
        assert_eq!(b[[3, 0]], 4);

        // Elements keep their row-major order in the new shape
        let c = reshape_2d(a.view(), (1, 4)).unwrap();
        assert_eq!(c, array![[1, 2, 3, 4]]);

        // Test invalid shape
        let result = reshape_2d(a.view(), (3, 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        // Stack vertically (along axis 0)
        let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
        assert_eq!(c.shape(), &[4, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[1, 0]], 3);
        assert_eq!(c[[2, 0]], 5);
        assert_eq!(c[[3, 0]], 7);

        // Stack horizontally (along axis 1)
        let d = stack_2d(&[a.view(), b.view()], 1).unwrap();
        assert_eq!(d.shape(), &[2, 4]);
        assert_eq!(d[[0, 0]], 1);
        assert_eq!(d[[0, 1]], 2);
        assert_eq!(d[[0, 2]], 5);
        assert_eq!(d[[0, 3]], 6);
    }

    #[test]
    fn test_stack_2d_errors() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];
        let wide = array![[1, 2, 3]];

        // Nothing to stack
        assert!(stack_2d::<i32>(&[], 0).is_err());
        // Mismatched shapes
        assert!(stack_2d(&[a.view(), wide.view()], 0).is_err());
        // Axis outside of a 2D array
        assert!(stack_2d(&[a.view(), b.view()], 2).is_err());
    }

    #[test]
    fn test_transpose_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = transpose_2d(a.view());
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 4);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 1]], 6);
    }

    #[test]
    fn test_split_2d() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Split along columns at index 2
        let result = split_2d(a.view(), &[2], 1).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 2]);
        assert_eq!(result[0][[0, 0]], 1);
        assert_eq!(result[0][[0, 1]], 2);
        assert_eq!(result[0][[1, 0]], 5);
        assert_eq!(result[0][[1, 1]], 6);
        assert_eq!(result[1].shape(), &[2, 2]);
        assert_eq!(result[1][[0, 0]], 3);
        assert_eq!(result[1][[0, 1]], 4);
        assert_eq!(result[1][[1, 0]], 7);
        assert_eq!(result[1][[1, 1]], 8);

        // Split along rows at index 1
        let result = split_2d(a.view(), &[1], 0).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[1, 4]);
        assert_eq!(result[1].shape(), &[1, 4]);
    }

    #[test]
    fn test_split_2d_edge_cases() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // No split points: a single copy of the input
        let whole = split_2d(a.view(), &[], 0).unwrap();
        assert_eq!(whole.len(), 1);
        assert_eq!(whole[0], a);

        // Split points are sorted before use
        let parts = split_2d(a.view(), &[3, 1], 1).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0], array![[1], [5]]);
        assert_eq!(parts[1], array![[2, 3], [6, 7]]);
        assert_eq!(parts[2], array![[4], [8]]);

        // Out-of-bounds split point and invalid axis
        assert!(split_2d(a.view(), &[4], 1).is_err());
        assert!(split_2d(a.view(), &[1], 2).is_err());
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let indices = array![0, 2];

        // Take along columns
        let result = take_2d(a.view(), indices.view(), 1).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 3);
        assert_eq!(result[[1, 0]], 4);
        assert_eq!(result[[1, 1]], 6);
    }

    #[test]
    fn test_take_2d_rows_and_errors() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        // Take along rows, in reverse order
        let swapped = array![1, 0];
        let result = take_2d(a.view(), swapped.view(), 0).unwrap();
        assert_eq!(result, array![[4, 5, 6], [1, 2, 3]]);

        // Index past the end of the chosen axis
        let too_far = array![3];
        assert!(take_2d(a.view(), too_far.view(), 1).is_err());

        // Axis outside of a 2D array
        let valid = array![0];
        assert!(take_2d(a.view(), valid.view(), 2).is_err());
    }

    #[test]
    fn test_mask_select() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let result = mask_select(a.view(), mask.view()).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 3);
        assert_eq!(result[2], 5);

        // A mask of a different shape is rejected
        let short_mask = array![[true, false, true]];
        assert!(mask_select(a.view(), short_mask.view()).is_err());
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(result.shape(), &[2]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 8);
    }

    #[test]
    fn test_fancy_index_2d_errors() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let two = array![0, 2];
        let one = array![0];
        let out_of_range = array![3];

        // Index arrays of different lengths
        assert!(fancy_index_2d(a.view(), two.view(), one.view()).is_err());
        // Row index out of bounds
        assert!(fancy_index_2d(a.view(), out_of_range.view(), one.view()).is_err());
        // Column index out of bounds
        assert!(fancy_index_2d(a.view(), one.view(), out_of_range.view()).is_err());
    }

    #[test]
    fn test_where_condition() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let result = where_condition(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 4);
        assert_eq!(result[1], 5);
        assert_eq!(result[2], 6);

        // No element matches: the result is empty
        let none = where_condition(a.view(), |&x| x > 100).unwrap();
        assert_eq!(none.len(), 0);
    }

    #[test]
    fn test_is_broadcast_compatible() {
        assert!(is_broadcast_compatible(&[2, 3], &[3]));
        assert!(is_broadcast_compatible(&[4, 1], &[1, 5]));
        assert!(is_broadcast_compatible(&[], &[2, 3]));
        assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
        assert!(!is_broadcast_compatible(&[2, 3], &[4]));
    }

    #[test]
    fn test_broadcastshape() {
        assert_eq!(broadcastshape(&[2, 3], &[3]), Some(vec![2, 3]));
        assert_eq!(broadcastshape(&[5, 1, 4], &[1, 3, 1]), Some(vec![5, 3, 4]));
        assert_eq!(broadcastshape(&[4, 1], &[1, 5]), Some(vec![4, 5]));
        assert_eq!(broadcastshape(&[2, 3], &[4]), None);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let a = array![1, 2, 3];

        // Broadcast along rows (axis 0)
        let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 2);
        assert_eq!(b[[1, 0]], 1);
        assert_eq!(b[[1, 2]], 3);

        // Broadcast along columns (axis 1)
        let c = broadcast_1d_to_2d(a.view(), 2, 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[0, 1]], 1);
        assert_eq!(c[[1, 0]], 2);
        assert_eq!(c[[2, 1]], 3);

        // Axis outside of a 2D array
        assert!(broadcast_1d_to_2d(a.view(), 2, 2).is_err());
    }

    #[test]
    fn test_broadcast_apply() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = array![10, 20, 30];

        let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 11);
        assert_eq!(result[[0, 1]], 22);
        assert_eq!(result[[0, 2]], 33);
        assert_eq!(result[[1, 0]], 14);
        assert_eq!(result[[1, 1]], 25);
        assert_eq!(result[[1, 2]], 36);

        let result = broadcast_apply(a.view(), b.view(), |x, y| x * y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 10);
        assert_eq!(result[[0, 1]], 40);
        assert_eq!(result[[0, 2]], 90);
        assert_eq!(result[[1, 0]], 40);
        assert_eq!(result[[1, 1]], 100);
        assert_eq!(result[[1, 2]], 180);

        // Length of `b` must equal the number of columns of `a`
        let short = array![10, 20];
        assert!(broadcast_apply(a.view(), short.view(), |x, y| x + y).is_err());
    }
}