scirs2_core/ndarray_ext/
mod.rs

1//! Extended ndarray operations for scientific computing
2//!
3//! This module provides additional functionality for ndarray to support
4//! the advanced array operations necessary for a complete SciPy port.
5//! It implements core `NumPy`-like features that are not available in the
6//! base ndarray crate.
7
8/// Re-export essential ndarray types and macros for convenience
9/// Use ::ndarray to refer to the external crate directly
10pub use ::ndarray::{
11    s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, DataMut, Dim, Dimension, Ix1, Ix2,
12    Ix3, Ix4, Ix5, Ix6, IxDyn, OwnedRepr, RemoveAxis, ScalarOperand, ShapeBuilder, ShapeError,
13    SliceInfo, ViewRepr, Zip,
14};
15
16// Re-export the array! macro for convenient array creation
17// This addresses the common issue where users expect array! to be available in scirs2-core
18// instead of requiring import from scirs2_autograd::ndarray
19pub use ::ndarray::{arr1, arr2, array};
20
21// Re-export type aliases for common array sizes
22pub use ::ndarray::{ArcArray1, ArcArray2};
23pub use ::ndarray::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD};
24pub use ::ndarray::{
25    ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD,
26};
27pub use ::ndarray::{
28    ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5,
29    ArrayViewMut6, ArrayViewMutD,
30};
31
32/// Advanced indexing operations (`NumPy`-like boolean masking, fancy indexing, etc.)
33pub mod indexing;
34
35/// Statistical functions for ndarray arrays (mean, median, variance, correlation, etc.)
36pub mod stats;
37
38/// Matrix operations (eye, diag, kron, etc.)
39pub mod matrix;
40
41/// Array manipulation operations (flip, roll, tile, repeat, etc.)
42pub mod manipulation;
43
44/// Array reduction operations with SIMD acceleration (argmin, argmax, etc.)
45pub mod reduction;
46
47/// Data preprocessing operations with SIMD acceleration (normalization, standardization)
48pub mod preprocessing;
49
50/// Element-wise mathematical operations with SIMD acceleration (abs, sqrt, etc.)
51pub mod elementwise;
52
53/// Random array generation with full ndarray-rand compatibility and scientific computing extensions
54#[cfg(feature = "random")]
55pub mod random;
56
57/// Reshape a 2D array to a new shape without copying data when possible
58///
59/// # Arguments
60///
61/// * `array` - The input array
62/// * `shape` - The new shape (rows, cols), which must be compatible with the original shape
63///
64/// # Returns
65///
66/// A reshaped array view of the input array if possible, or a new array otherwise
67///
68/// # Examples
69///
70/// ```
71/// use scirs2_core::ndarray::array;
72/// use scirs2_core::ndarray_ext::reshape_2d;
73///
74/// let a = array![[1, 2], [3, 4]];
75/// let b = reshape_2d(a.view(), (4, 1)).unwrap();
76/// assert_eq!(b.shape(), &[4, 1]);
77/// assert_eq!(b[[0, 0]], 1);
78/// assert_eq!(b[[3, 0]], 4);
79/// ```
80#[allow(dead_code)]
81pub fn reshape_2d<T>(
82    array: ArrayView<T, Ix2>,
83    shape: (usize, usize),
84) -> Result<Array<T, Ix2>, &'static str>
85where
86    T: Clone + Default,
87{
88    let (rows, cols) = shape;
89    let total_elements = rows * cols;
90
91    // Check if the new shape is compatible with the original shape
92    if total_elements != array.len() {
93        return Err("New shape dimensions must match the total number of elements");
94    }
95
96    // Create a new array with the specified shape
97    let mut result = Array::<T, Ix2>::default(shape);
98
99    // Fill the result array with elements from the input array
100    let flat_iter = array.iter();
101    for (i, val) in flat_iter.enumerate() {
102        let r = i / cols;
103        let c = i % cols;
104        result[[r, c]] = val.clone();
105    }
106
107    Ok(result)
108}
109
110/// Stack 2D arrays along a given axis
111///
112/// # Arguments
113///
114/// * `arrays` - A slice of 2D arrays to stack
115/// * `axis` - The axis along which to stack (0 for rows, 1 for columns)
116///
117/// # Returns
118///
119/// A new array containing the stacked arrays
120///
121/// # Examples
122///
123/// ```
124/// use ndarray::{array, Axis};
125/// use scirs2_core::ndarray_ext::stack_2d;
126///
127/// let a = array![[1, 2], [3, 4]];
128/// let b = array![[5, 6], [7, 8]];
129/// let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
130/// assert_eq!(c.shape(), &[4, 2]);
131/// ```
132#[allow(dead_code)]
133pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
134where
135    T: Clone + Default,
136{
137    if arrays.is_empty() {
138        return Err("No _arrays provided for stacking");
139    }
140
141    // Validate that all _arrays have the same shape
142    let firstshape = arrays[0].shape();
143    for array in arrays.iter().skip(1) {
144        if array.shape() != firstshape {
145            return Err("All _arrays must have the same shape for stacking");
146        }
147    }
148
149    let (rows, cols) = (firstshape[0], firstshape[1]);
150
151    // Calculate the new shape
152    let (new_rows, new_cols) = match axis {
153        0 => (rows * arrays.len(), cols), // Stack vertically
154        1 => (rows, cols * arrays.len()), // Stack horizontally
155        _ => return Err("Axis must be 0 or 1 for 2D _arrays"),
156    };
157
158    // Create a new array to hold the stacked result
159    let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
160
161    // Copy data from the input _arrays to the result
162    match axis {
163        0 => {
164            // Stack vertically (along rows)
165            for (array_idx, array) in arrays.iter().enumerate() {
166                let start_row = array_idx * rows;
167                for r in 0..rows {
168                    for c in 0..cols {
169                        result[[start_row + r, c]] = array[[r, c]].clone();
170                    }
171                }
172            }
173        }
174        1 => {
175            // Stack horizontally (along columns)
176            for (array_idx, array) in arrays.iter().enumerate() {
177                let start_col = array_idx * cols;
178                for r in 0..rows {
179                    for c in 0..cols {
180                        result[[r, start_col + c]] = array[[r, c]].clone();
181                    }
182                }
183            }
184        }
185        _ => unreachable!(),
186    }
187
188    Ok(result)
189}
190
191/// Swap axes (transpose) of a 2D array
192///
193/// # Arguments
194///
195/// * `array` - The input 2D array
196///
197/// # Returns
198///
199/// A view of the input array with the axes swapped
200///
201/// # Examples
202///
203/// ```ignore
204/// use crate::ndarray::array;
205/// use scirs2_core::ndarray_ext::transpose_2d;
206///
207/// let a = array![[1, 2, 3], [4, 5, 6]];
208/// let b = transpose_2d(a.view());
209/// assert_eq!(b.shape(), &[3, 2]);
210/// assert_eq!(b[[0, 0]], 1);
211/// assert_eq!(b[[0, 1]], 4);
212/// ```
213#[allow(dead_code)]
214pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
215where
216    T: Clone,
217{
218    array.t().to_owned()
219}
220
221/// Split a 2D array into multiple sub-arrays along a given axis
222///
223/// # Arguments
224///
225/// * `array` - The input 2D array to split
226/// * `indices` - Indices where the array should be split
227/// * `axis` - The axis along which to split (0 for rows, 1 for columns)
228///
229/// # Returns
230///
231/// A vector of arrays resulting from the split
232///
233/// # Examples
234///
235/// ```
236/// use scirs2_core::ndarray::array;
237/// use scirs2_core::ndarray_ext::split_2d;
238///
239/// let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
240/// let result = split_2d(a.view(), &[2], 1).unwrap();
241/// assert_eq!(result.len(), 2);
242/// assert_eq!(result[0].shape(), &[2, 2]);
243/// assert_eq!(result[1].shape(), &[2, 2]);
244/// ```
245#[allow(dead_code)]
246pub fn split_2d<T>(
247    array: ArrayView<T, Ix2>,
248    indices: &[usize],
249    axis: usize,
250) -> Result<Vec<Array<T, Ix2>>, &'static str>
251where
252    T: Clone + Default,
253{
254    if indices.is_empty() {
255        return Ok(vec![array.to_owned()]);
256    }
257
258    let (rows, cols) = (array.shape()[0], array.shape()[1]);
259    let axis_len = if axis == 0 { rows } else { cols };
260
261    // Validate indices
262    for &idx in indices {
263        if idx >= axis_len {
264            return Err("Split index out of bounds");
265        }
266    }
267
268    // Sort indices to ensure they're in ascending order
269    let mut sorted_indices = indices.to_vec();
270    sorted_indices.sort_unstable();
271
272    // Calculate the sub-array boundaries
273    let mut starts = vec![0];
274    starts.extend_from_slice(&sorted_indices);
275
276    let mut ends = sorted_indices.clone();
277    ends.push(axis_len);
278
279    // Create the split sub-arrays
280    let mut result = Vec::with_capacity(starts.len());
281
282    match axis {
283        0 => {
284            // Split along rows
285            for (&start, &end) in starts.iter().zip(ends.iter()) {
286                let sub_rows = end - start;
287                let mut sub_array = Array::<T, Ix2>::default((sub_rows, cols));
288
289                for r in 0..sub_rows {
290                    for c in 0..cols {
291                        sub_array[[r, c]] = array[[start + r, c]].clone();
292                    }
293                }
294
295                result.push(sub_array);
296            }
297        }
298        1 => {
299            // Split along columns
300            for (&start, &end) in starts.iter().zip(ends.iter()) {
301                let sub_cols = end - start;
302                let mut sub_array = Array::<T, Ix2>::default((rows, sub_cols));
303
304                for r in 0..rows {
305                    for c in 0..sub_cols {
306                        sub_array[[r, c]] = array[[r, start + c]].clone();
307                    }
308                }
309
310                result.push(sub_array);
311            }
312        }
313        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
314    }
315
316    Ok(result)
317}
318
319/// Take elements from a 2D array along a given axis using indices from another array
320///
321/// # Arguments
322///
323/// * `array` - The input 2D array
324/// * `indices` - Array of indices to take
325/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
326///
327/// # Returns
328///
329/// An array of values at the specified indices along the given axis
330///
331/// # Examples
332///
333/// ```
334/// use scirs2_core::ndarray::array;
335/// use scirs2_core::ndarray_ext::take_2d;
336///
337/// let a = array![[1, 2, 3], [4, 5, 6]];
338/// let indices = array![0_usize, 2];
339/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
340/// assert_eq!(result.shape(), &[2, 2]);
341/// assert_eq!(result[[0, 0]], 1);
342/// assert_eq!(result[[0, 1]], 3);
343/// ```
344#[allow(dead_code)]
345pub fn take_2d<T>(
346    array: ArrayView<T, Ix2>,
347    indices: ArrayView<usize, Ix1>,
348    axis: usize,
349) -> Result<Array<T, Ix2>, &'static str>
350where
351    T: Clone + Default,
352{
353    let (rows, cols) = (array.shape()[0], array.shape()[1]);
354    let axis_len = if axis == 0 { rows } else { cols };
355
356    // Check that indices are within bounds
357    for &idx in indices.iter() {
358        if idx >= axis_len {
359            return Err("Index out of bounds");
360        }
361    }
362
363    // Create the result array with the appropriate shape
364    let (result_rows, result_cols) = match axis {
365        0 => (indices.len(), cols),
366        1 => (rows, indices.len()),
367        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
368    };
369
370    let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
371
372    // Fill the result array
373    match axis {
374        0 => {
375            // Take along rows
376            for (i, &idx) in indices.iter().enumerate() {
377                for j in 0..cols {
378                    result[[i, j]] = array[[idx, j]].clone();
379                }
380            }
381        }
382        1 => {
383            // Take along columns
384            for i in 0..rows {
385                for (j, &idx) in indices.iter().enumerate() {
386                    result[[i, j]] = array[[i, idx]].clone();
387                }
388            }
389        }
390        _ => unreachable!(),
391    }
392
393    Ok(result)
394}
395
396/// Filter an array using a boolean mask
397///
398/// # Arguments
399///
400/// * `array` - The input array
401/// * `mask` - Boolean mask of the same shape as the array
402///
403/// # Returns
404///
405/// A 1D array containing the elements where the mask is true
406///
407/// # Examples
408///
409/// ```ignore
410/// use crate::ndarray::array;
411/// use scirs2_core::ndarray_ext::mask_select;
412///
413/// let a = array![[1, 2, 3], [4, 5, 6]];
414/// let mask = array![[true, false, true], [false, true, false]];
415/// let result = mask_select(a.view(), mask.view()).unwrap();
416/// assert_eq!(result.shape(), &[3]);
417/// assert_eq!(result[0], 1);
418/// assert_eq!(result[1], 3);
419/// assert_eq!(result[2], 5);
420/// ```
421#[allow(dead_code)]
422pub fn mask_select<T>(
423    array: ArrayView<T, Ix2>,
424    mask: ArrayView<bool, Ix2>,
425) -> Result<Array<T, Ix1>, &'static str>
426where
427    T: Clone + Default,
428{
429    // Check that the mask has the same shape as the array
430    if array.shape() != mask.shape() {
431        return Err("Mask shape must match array shape");
432    }
433
434    // Count the number of true values in the mask
435    let true_count = mask.iter().filter(|&&x| x).count();
436
437    // Create the result array
438    let mut result = Array::<T, Ix1>::default(true_count);
439
440    // Fill the result array with elements where the mask is true
441    let mut idx = 0;
442    for (val, &m) in array.iter().zip(mask.iter()) {
443        if m {
444            result[idx] = val.clone();
445            idx += 1;
446        }
447    }
448
449    Ok(result)
450}
451
452/// Index a 2D array with a list of index arrays (fancy indexing)
453///
454/// # Arguments
455///
456/// * `array` - The input 2D array
457/// * `row_indices` - Array of row indices
458/// * `col_indices` - Array of column indices
459///
460/// # Returns
461///
462/// A 1D array containing the elements at the specified indices
463///
464/// # Examples
465///
466/// ```ignore
467/// use crate::ndarray::array;
468/// use scirs2_core::ndarray_ext::fancy_index_2d;
469///
470/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
471/// let row_indices = array![0, 2];
472/// let col_indices = array![0, 1];
473/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
474/// assert_eq!(result.shape(), &[2]);
475/// assert_eq!(result[0], 1);
476/// assert_eq!(result[1], 8);
477/// ```
478#[allow(dead_code)]
479pub fn fancy_index_2d<T>(
480    array: ArrayView<T, Ix2>,
481    row_indices: ArrayView<usize, Ix1>,
482    col_indices: ArrayView<usize, Ix1>,
483) -> Result<Array<T, Ix1>, &'static str>
484where
485    T: Clone + Default,
486{
487    // Check that all index arrays have the same length
488    let result_size = row_indices.len();
489    if col_indices.len() != result_size {
490        return Err("Row and column index arrays must have the same length");
491    }
492
493    let (rows, cols) = (array.shape()[0], array.shape()[1]);
494
495    // Check that _indices are within bounds
496    for &idx in row_indices.iter() {
497        if idx >= rows {
498            return Err("Row index out of bounds");
499        }
500    }
501
502    for &idx in col_indices.iter() {
503        if idx >= cols {
504            return Err("Column index out of bounds");
505        }
506    }
507
508    // Create the result array
509    let mut result = Array::<T, Ix1>::default(result_size);
510
511    // Fill the result array
512    for i in 0..result_size {
513        let row = row_indices[i];
514        let col = col_indices[i];
515        result[i] = array[[row, col]].clone();
516    }
517
518    Ok(result)
519}
520
521/// Select elements from an array where a condition is true
522///
523/// # Arguments
524///
525/// * `array` - The input array
526/// * `condition` - A function that takes a reference to an element and returns a bool
527///
528/// # Returns
529///
530/// A 1D array containing the elements where the condition is true
531///
532/// # Examples
533///
534/// ```ignore
535/// use crate::ndarray::array;
536/// use scirs2_core::ndarray_ext::where_condition;
537///
538/// let a = array![[1, 2, 3], [4, 5, 6]];
539/// let result = where_condition(a.view(), |&x| x > 3).unwrap();
540/// assert_eq!(result.shape(), &[3]);
541/// assert_eq!(result[0], 4);
542/// assert_eq!(result[1], 5);
543/// assert_eq!(result[2], 6);
544/// ```
545#[allow(dead_code)]
546pub fn where_condition<T, F>(
547    array: ArrayView<T, Ix2>,
548    condition: F,
549) -> Result<Array<T, Ix1>, &'static str>
550where
551    T: Clone + Default,
552    F: Fn(&T) -> bool,
553{
554    // Build a boolean mask array based on the condition
555    let mask = array.map(condition);
556
557    // Use the mask_select function to select elements
558    mask_select(array, mask.view())
559}
560
/// Check if two shapes are broadcast compatible
///
/// Shapes are aligned at their trailing dimensions; the shorter shape is
/// treated as if it were padded on the left with dimensions of size 1. Two
/// aligned dimensions are compatible when they are equal or when either of
/// them is 1.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// `true` if the shapes are broadcast compatible, `false` otherwise
///
/// # Examples
///
/// ```ignore
/// use scirs2_core::ndarray_ext::is_broadcast_compatible;
///
/// assert!(is_broadcast_compatible(&[2, 3], &[3]));
/// // This example has dimensions that don't match (5 vs 3 in dimension 0)
/// // and aren't broadcasting compatible (neither is 1)
/// assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
/// assert!(!is_broadcast_compatible(&[2, 3], &[4]));
/// ```
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    // Pair the dimensions from the last axis backwards. `zip` stops at the
    // shorter shape; the unpaired leading dimensions of the longer shape
    // face an implicit 1 and are therefore always compatible.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
611
/// Calculate the broadcasted shape from two input shapes
///
/// Shapes are aligned at their trailing dimensions; the shorter shape is
/// treated as if it were padded on the left with dimensions of size 1. For
/// each aligned pair, a dimension of size 1 stretches to match the other
/// dimension. This also holds when the other dimension is 0, so broadcasting
/// `[0]` with `[1]` yields `[0]`.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// The broadcasted shape as a `Vec<usize>`, or `None` if the shapes are incompatible
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::broadcastshape;
///
/// assert_eq!(broadcastshape(&[2, 3], &[3]), Some(vec![2, 3]));
/// assert_eq!(broadcastshape(&[4, 1], &[1, 5]), Some(vec![4, 5]));
/// assert_eq!(broadcastshape(&[2, 3], &[4]), None);
/// ```
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    // Order the inputs by rank so the padding only ever applies to `shorter`
    let (longer, shorter) = if shape1.len() >= shape2.len() {
        (shape1, shape2)
    } else {
        (shape2, shape1)
    };
    let offset = longer.len() - shorter.len();

    let mut result = Vec::with_capacity(longer.len());

    // Leading dimensions of the longer shape face an implicit 1, so they are
    // carried over unchanged (a zero-length axis stays zero-length).
    result.extend_from_slice(&longer[..offset]);

    // Remaining dimensions are paired one-to-one
    for (&lhs, &rhs) in longer[offset..].iter().zip(shorter.iter()) {
        let dim = if lhs == rhs || rhs == 1 {
            lhs
        } else if lhs == 1 {
            rhs
        } else {
            // Neither equal nor stretchable: the shapes cannot be broadcast
            return None;
        };
        result.push(dim);
    }

    Some(result)
}
653
654/// Broadcast a 1D array to a 2D shape by repeating it along the specified axis
655///
656/// # Arguments
657///
658/// * `array` - The input 1D array
659/// * `repeats` - Number of times to repeat the array
660/// * `axis` - Axis along which to repeat (0 for rows, 1 for columns)
661///
662/// # Returns
663///
664/// A 2D array with the input array repeated along the specified axis
665///
666/// # Examples
667///
668/// ```ignore
669/// use crate::ndarray::array;
670/// use scirs2_core::ndarray_ext::broadcast_1d_to_2d;
671///
672/// let a = array![1, 2, 3];
673/// let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
674/// assert_eq!(b.shape(), &[2, 3]);
675/// assert_eq!(b[[0, 0]], 1);
676/// assert_eq!(b[[1, 0]], 1);
677/// ```
678#[allow(dead_code)]
679pub fn broadcast_1d_to_2d<T>(
680    array: ArrayView<T, Ix1>,
681    repeats: usize,
682    axis: usize,
683) -> Result<Array<T, Ix2>, &'static str>
684where
685    T: Clone + Default,
686{
687    let len = array.len();
688
689    // Create the result array with the appropriate shape
690    let (rows, cols) = match axis {
691        0 => (repeats, len), // Broadcast along rows
692        1 => (len, repeats), // Broadcast along columns
693        _ => return Err("Axis must be 0 or 1"),
694    };
695
696    let mut result = Array::<T, Ix2>::default((rows, cols));
697
698    // Fill the result array
699    match axis {
700        0 => {
701            // Broadcast along rows
702            for i in 0..repeats {
703                for j in 0..len {
704                    result[[i, j]] = array[j].clone();
705                }
706            }
707        }
708        1 => {
709            // Broadcast along columns
710            for i in 0..len {
711                for j in 0..repeats {
712                    result[[i, j]] = array[i].clone();
713                }
714            }
715        }
716        _ => unreachable!(),
717    }
718
719    Ok(result)
720}
721
722/// Apply an element-wise binary operation to two arrays with broadcasting
723///
724/// # Arguments
725///
726/// * `a` - First array (2D)
727/// * `b` - Second array (can be 1D or 2D)
728/// * `op` - Binary operation to apply to each pair of elements
729///
730/// # Returns
731///
732/// A 2D array containing the result of the operation applied element-wise
733///
734/// # Examples
735///
736/// ```ignore
737/// use crate::ndarray::array;
738/// use scirs2_core::ndarray_ext::broadcast_apply;
739///
740/// let a = array![[1, 2, 3], [4, 5, 6]];
741/// let b = array![10, 20, 30];
742/// let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
743/// assert_eq!(result.shape(), &[2, 3]);
744/// assert_eq!(result[[0, 0]], 11);
745/// assert_eq!(result[[1, 2]], 36);
746/// ```
747#[allow(dead_code)]
748pub fn broadcast_apply<T, R, F>(
749    a: ArrayView<T, Ix2>,
750    b: ArrayView<T, Ix1>,
751    op: F,
752) -> Result<Array<R, Ix2>, &'static str>
753where
754    T: Clone + Default,
755    R: Clone + Default,
756    F: Fn(&T, &T) -> R,
757{
758    let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
759    let b_len = b.len();
760
761    // Check that the arrays are broadcast compatible
762    if a_cols != b_len {
763        return Err("Arrays are not broadcast compatible");
764    }
765
766    // Create the result array
767    let mut result = Array::<R, Ix2>::default((a_rows, a_cols));
768
769    // Apply the operation element-wise
770    for i in 0..a_rows {
771        for j in 0..a_cols {
772            result[[i, j]] = op(&a[[i, j]], &b[j]);
773        }
774    }
775
776    Ok(result)
777}
778
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::array;

    /// Asserts that `actual[[r, c]]` equals `value` for every listed `((r, c), value)` pair.
    fn assert_cells(actual: &Array<i32, Ix2>, expected: &[((usize, usize), i32)]) {
        for &((r, c), value) in expected {
            assert_eq!(actual[[r, c]], value, "mismatch at [{}, {}]", r, c);
        }
    }

    #[test]
    fn test_reshape_2d() {
        let source = array![[1, 2], [3, 4]];

        let column = reshape_2d(source.view(), (4, 1)).unwrap();
        assert_eq!(column.shape(), &[4, 1]);
        assert_cells(
            &column,
            &[((0, 0), 1), ((1, 0), 2), ((2, 0), 3), ((3, 0), 4)],
        );

        // A shape with a different element count must be rejected
        assert!(reshape_2d(source.view(), (3, 1)).is_err());
    }

    #[test]
    fn test_stack_2d() {
        let top = array![[1, 2], [3, 4]];
        let bottom = array![[5, 6], [7, 8]];
        let inputs = [top.view(), bottom.view()];

        // Axis 0: the inputs are placed one below another
        let vertical = stack_2d(&inputs, 0).unwrap();
        assert_eq!(vertical.shape(), &[4, 2]);
        assert_cells(
            &vertical,
            &[((0, 0), 1), ((1, 0), 3), ((2, 0), 5), ((3, 0), 7)],
        );

        // Axis 1: the inputs are placed side by side
        let horizontal = stack_2d(&inputs, 1).unwrap();
        assert_eq!(horizontal.shape(), &[2, 4]);
        assert_cells(
            &horizontal,
            &[((0, 0), 1), ((0, 1), 2), ((0, 2), 5), ((0, 3), 6)],
        );
    }

    #[test]
    fn test_transpose_2d() {
        let source = array![[1, 2, 3], [4, 5, 6]];

        let transposed = transpose_2d(source.view());
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_cells(
            &transposed,
            &[((0, 0), 1), ((0, 1), 4), ((1, 0), 2), ((2, 1), 6)],
        );
    }

    #[test]
    fn test_split_2d() {
        let source = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Columns split at index 2 give a left and a right half
        let halves = split_2d(source.view(), &[2], 1).unwrap();
        assert_eq!(halves.len(), 2);
        assert_eq!(halves[0].shape(), &[2, 2]);
        assert_cells(
            &halves[0],
            &[((0, 0), 1), ((0, 1), 2), ((1, 0), 5), ((1, 1), 6)],
        );
        assert_eq!(halves[1].shape(), &[2, 2]);
        assert_cells(
            &halves[1],
            &[((0, 0), 3), ((0, 1), 4), ((1, 0), 7), ((1, 1), 8)],
        );

        // Rows split at index 1 give two single-row pieces
        let row_pieces = split_2d(source.view(), &[1], 0).unwrap();
        assert_eq!(row_pieces.len(), 2);
        for piece in &row_pieces {
            assert_eq!(piece.shape(), &[1, 4]);
        }
    }

    #[test]
    fn test_take_2d() {
        let source = array![[1, 2, 3], [4, 5, 6]];
        let wanted_cols = array![0, 2];

        // Pick the first and last column
        let taken = take_2d(source.view(), wanted_cols.view(), 1).unwrap();
        assert_eq!(taken.shape(), &[2, 2]);
        assert_cells(
            &taken,
            &[((0, 0), 1), ((0, 1), 3), ((1, 0), 4), ((1, 1), 6)],
        );
    }

    #[test]
    fn test_mask_select() {
        let source = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let selected = mask_select(source.view(), mask.view()).unwrap();
        assert_eq!(selected.to_vec(), vec![1, 3, 5]);
    }

    #[test]
    fn test_fancy_index_2d() {
        let source = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        let picked =
            fancy_index_2d(source.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(picked.to_vec(), vec![1, 8]);
    }

    #[test]
    fn test_where_condition() {
        let source = array![[1, 2, 3], [4, 5, 6]];

        let above_three = where_condition(source.view(), |&x| x > 3).unwrap();
        assert_eq!(above_three.to_vec(), vec![4, 5, 6]);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let source = array![1, 2, 3];

        // Axis 0: the input is repeated as rows
        let as_rows = broadcast_1d_to_2d(source.view(), 2, 0).unwrap();
        assert_eq!(as_rows.shape(), &[2, 3]);
        assert_cells(
            &as_rows,
            &[((0, 0), 1), ((0, 1), 2), ((1, 0), 1), ((1, 2), 3)],
        );

        // Axis 1: the input is repeated as columns
        let as_cols = broadcast_1d_to_2d(source.view(), 2, 1).unwrap();
        assert_eq!(as_cols.shape(), &[3, 2]);
        assert_cells(
            &as_cols,
            &[((0, 0), 1), ((0, 1), 1), ((1, 0), 2), ((2, 1), 3)],
        );
    }

    #[test]
    fn test_broadcast_apply() {
        let matrix = array![[1, 2, 3], [4, 5, 6]];
        let row = array![10, 20, 30];

        let sums = broadcast_apply(matrix.view(), row.view(), |x, y| x + y).unwrap();
        assert_eq!(sums.shape(), &[2, 3]);
        assert_cells(
            &sums,
            &[
                ((0, 0), 11),
                ((0, 1), 22),
                ((0, 2), 33),
                ((1, 0), 14),
                ((1, 1), 25),
                ((1, 2), 36),
            ],
        );

        let products = broadcast_apply(matrix.view(), row.view(), |x, y| x * y).unwrap();
        assert_eq!(products.shape(), &[2, 3]);
        assert_cells(
            &products,
            &[
                ((0, 0), 10),
                ((0, 1), 40),
                ((0, 2), 90),
                ((1, 0), 40),
                ((1, 1), 100),
                ((1, 2), 180),
            ],
        );
    }
}