scirs2_core/ndarray_ext/
mod.rs

1//! Extended ndarray operations for scientific computing
2//!
3//! This module provides additional functionality for ndarray to support
4//! the advanced array operations necessary for a complete SciPy port.
5//! It implements core `NumPy`-like features that are not available in the
6//! base ndarray crate.
7
8/// Re-export essential ndarray types for convenience
9pub use ndarray::{
10    Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Dim, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6,
11    IxDyn, OwnedRepr, ShapeBuilder, SliceInfo, ViewRepr,
12};
13
14/// Advanced indexing operations (`NumPy`-like boolean masking, fancy indexing, etc.)
15pub mod indexing;
16
17/// Statistical functions for ndarray arrays (mean, median, variance, correlation, etc.)
18pub mod stats;
19
20/// Matrix operations (eye, diag, kron, etc.)
21pub mod matrix;
22
23/// Array manipulation operations (flip, roll, tile, repeat, etc.)
24pub mod manipulation;
25
26/// Reshape a 2D array to a new shape without copying data when possible
27///
28/// # Arguments
29///
30/// * `array` - The input array
31/// * `shape` - The new shape (rows, cols), which must be compatible with the original shape
32///
33/// # Returns
34///
35/// A reshaped array view of the input array if possible, or a new array otherwise
36///
37/// # Examples
38///
39/// ```
40/// use ndarray::array;
41/// use scirs2_core::ndarray_ext::reshape_2d;
42///
43/// let a = array![[1, 2], [3, 4]];
44/// let b = reshape_2d(a.view(), (4, 1)).unwrap();
45/// assert_eq!(b.shape(), &[4, 1]);
46/// assert_eq!(b[[0, 0]], 1);
47/// assert_eq!(b[[3, 0]], 4);
48/// ```
49#[allow(dead_code)]
50pub fn reshape_2d<T>(
51    array: ArrayView<T, Ix2>,
52    shape: (usize, usize),
53) -> Result<Array<T, Ix2>, &'static str>
54where
55    T: Clone + Default,
56{
57    let (rows, cols) = shape;
58    let total_elements = rows * cols;
59
60    // Check if the new shape is compatible with the original shape
61    if total_elements != array.len() {
62        return Err("New shape dimensions must match the total number of elements");
63    }
64
65    // Create a new array with the specified shape
66    let mut result = Array::<T, Ix2>::default(shape);
67
68    // Fill the result array with elements from the input array
69    let flat_iter = array.iter();
70    for (i, val) in flat_iter.enumerate() {
71        let r = i / cols;
72        let c = i % cols;
73        result[[r, c]] = val.clone();
74    }
75
76    Ok(result)
77}
78
79/// Stack 2D arrays along a given axis
80///
81/// # Arguments
82///
83/// * `arrays` - A slice of 2D arrays to stack
84/// * `axis` - The axis along which to stack (0 for rows, 1 for columns)
85///
86/// # Returns
87///
88/// A new array containing the stacked arrays
89///
90/// # Examples
91///
92/// ```
93/// use ndarray::{array, Axis};
94/// use scirs2_core::ndarray_ext::stack_2d;
95///
96/// let a = array![[1, 2], [3, 4]];
97/// let b = array![[5, 6], [7, 8]];
98/// let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
99/// assert_eq!(c.shape(), &[4, 2]);
100/// ```
101#[allow(dead_code)]
102pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
103where
104    T: Clone + Default,
105{
106    if arrays.is_empty() {
107        return Err("No _arrays provided for stacking");
108    }
109
110    // Validate that all _arrays have the same shape
111    let firstshape = arrays[0].shape();
112    for array in arrays.iter().skip(1) {
113        if array.shape() != firstshape {
114            return Err("All _arrays must have the same shape for stacking");
115        }
116    }
117
118    let (rows, cols) = (firstshape[0], firstshape[1]);
119
120    // Calculate the new shape
121    let (new_rows, new_cols) = match axis {
122        0 => (rows * arrays.len(), cols), // Stack vertically
123        1 => (rows, cols * arrays.len()), // Stack horizontally
124        _ => return Err("Axis must be 0 or 1 for 2D _arrays"),
125    };
126
127    // Create a new array to hold the stacked result
128    let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
129
130    // Copy data from the input _arrays to the result
131    match axis {
132        0 => {
133            // Stack vertically (along rows)
134            for (array_idx, array) in arrays.iter().enumerate() {
135                let start_row = array_idx * rows;
136                for r in 0..rows {
137                    for c in 0..cols {
138                        result[[start_row + r, c]] = array[[r, c]].clone();
139                    }
140                }
141            }
142        }
143        1 => {
144            // Stack horizontally (along columns)
145            for (array_idx, array) in arrays.iter().enumerate() {
146                let start_col = array_idx * cols;
147                for r in 0..rows {
148                    for c in 0..cols {
149                        result[[r, start_col + c]] = array[[r, c]].clone();
150                    }
151                }
152            }
153        }
154        _ => unreachable!(),
155    }
156
157    Ok(result)
158}
159
160/// Swap axes (transpose) of a 2D array
161///
162/// # Arguments
163///
164/// * `array` - The input 2D array
165///
166/// # Returns
167///
168/// A view of the input array with the axes swapped
169///
170/// # Examples
171///
172/// ```
173/// use ndarray::array;
174/// use scirs2_core::ndarray_ext::transpose_2d;
175///
176/// let a = array![[1, 2, 3], [4, 5, 6]];
177/// let b = transpose_2d(a.view());
178/// assert_eq!(b.shape(), &[3, 2]);
179/// assert_eq!(b[[0, 0]], 1);
180/// assert_eq!(b[[0, 1]], 4);
181/// ```
182#[allow(dead_code)]
183pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
184where
185    T: Clone,
186{
187    array.t().to_owned()
188}
189
190/// Split a 2D array into multiple sub-arrays along a given axis
191///
192/// # Arguments
193///
194/// * `array` - The input 2D array to split
195/// * `indices` - Indices where the array should be split
196/// * `axis` - The axis along which to split (0 for rows, 1 for columns)
197///
198/// # Returns
199///
200/// A vector of arrays resulting from the split
201///
202/// # Examples
203///
204/// ```
205/// use ndarray::array;
206/// use scirs2_core::ndarray_ext::split_2d;
207///
208/// let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
209/// let result = split_2d(a.view(), &[2], 1).unwrap();
210/// assert_eq!(result.len(), 2);
211/// assert_eq!(result[0].shape(), &[2, 2]);
212/// assert_eq!(result[1].shape(), &[2, 2]);
213/// ```
214#[allow(dead_code)]
215pub fn split_2d<T>(
216    array: ArrayView<T, Ix2>,
217    indices: &[usize],
218    axis: usize,
219) -> Result<Vec<Array<T, Ix2>>, &'static str>
220where
221    T: Clone + Default,
222{
223    if indices.is_empty() {
224        return Ok(vec![array.to_owned()]);
225    }
226
227    let (rows, cols) = (array.shape()[0], array.shape()[1]);
228    let axis_len = if axis == 0 { rows } else { cols };
229
230    // Validate indices
231    for &idx in indices {
232        if idx >= axis_len {
233            return Err("Split index out of bounds");
234        }
235    }
236
237    // Sort indices to ensure they're in ascending order
238    let mut sorted_indices = indices.to_vec();
239    sorted_indices.sort_unstable();
240
241    // Calculate the sub-array boundaries
242    let mut starts = vec![0];
243    starts.extend_from_slice(&sorted_indices);
244
245    let mut ends = sorted_indices.clone();
246    ends.push(axis_len);
247
248    // Create the split sub-arrays
249    let mut result = Vec::with_capacity(starts.len());
250
251    match axis {
252        0 => {
253            // Split along rows
254            for (&start, &end) in starts.iter().zip(ends.iter()) {
255                let sub_rows = end - start;
256                let mut sub_array = Array::<T, Ix2>::default((sub_rows, cols));
257
258                for r in 0..sub_rows {
259                    for c in 0..cols {
260                        sub_array[[r, c]] = array[[start + r, c]].clone();
261                    }
262                }
263
264                result.push(sub_array);
265            }
266        }
267        1 => {
268            // Split along columns
269            for (&start, &end) in starts.iter().zip(ends.iter()) {
270                let sub_cols = end - start;
271                let mut sub_array = Array::<T, Ix2>::default((rows, sub_cols));
272
273                for r in 0..rows {
274                    for c in 0..sub_cols {
275                        sub_array[[r, c]] = array[[r, start + c]].clone();
276                    }
277                }
278
279                result.push(sub_array);
280            }
281        }
282        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
283    }
284
285    Ok(result)
286}
287
288/// Take elements from a 2D array along a given axis using indices from another array
289///
290/// # Arguments
291///
292/// * `array` - The input 2D array
293/// * `indices` - Array of indices to take
294/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
295///
296/// # Returns
297///
298/// An array of values at the specified indices along the given axis
299///
300/// # Examples
301///
302/// ```
303/// use ndarray::array;
304/// use scirs2_core::ndarray_ext::take_2d;
305///
306/// let a = array![[1, 2, 3], [4, 5, 6]];
307/// let indices = array![0, 2];
308/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
309/// assert_eq!(result.shape(), &[2, 2]);
310/// assert_eq!(result[[0, 0]], 1);
311/// assert_eq!(result[[0, 1]], 3);
312/// ```
313#[allow(dead_code)]
314pub fn take_2d<T>(
315    array: ArrayView<T, Ix2>,
316    indices: ArrayView<usize, Ix1>,
317    axis: usize,
318) -> Result<Array<T, Ix2>, &'static str>
319where
320    T: Clone + Default,
321{
322    let (rows, cols) = (array.shape()[0], array.shape()[1]);
323    let axis_len = if axis == 0 { rows } else { cols };
324
325    // Check that indices are within bounds
326    for &idx in indices.iter() {
327        if idx >= axis_len {
328            return Err("Index out of bounds");
329        }
330    }
331
332    // Create the result array with the appropriate shape
333    let (result_rows, result_cols) = match axis {
334        0 => (indices.len(), cols),
335        1 => (rows, indices.len()),
336        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
337    };
338
339    let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
340
341    // Fill the result array
342    match axis {
343        0 => {
344            // Take along rows
345            for (i, &idx) in indices.iter().enumerate() {
346                for j in 0..cols {
347                    result[[i, j]] = array[[idx, j]].clone();
348                }
349            }
350        }
351        1 => {
352            // Take along columns
353            for i in 0..rows {
354                for (j, &idx) in indices.iter().enumerate() {
355                    result[[i, j]] = array[[i, idx]].clone();
356                }
357            }
358        }
359        _ => unreachable!(),
360    }
361
362    Ok(result)
363}
364
365/// Filter an array using a boolean mask
366///
367/// # Arguments
368///
369/// * `array` - The input array
370/// * `mask` - Boolean mask of the same shape as the array
371///
372/// # Returns
373///
374/// A 1D array containing the elements where the mask is true
375///
376/// # Examples
377///
378/// ```
379/// use ndarray::array;
380/// use scirs2_core::ndarray_ext::mask_select;
381///
382/// let a = array![[1, 2, 3], [4, 5, 6]];
383/// let mask = array![[true, false, true], [false, true, false]];
384/// let result = mask_select(a.view(), mask.view()).unwrap();
385/// assert_eq!(result.shape(), &[3]);
386/// assert_eq!(result[0], 1);
387/// assert_eq!(result[1], 3);
388/// assert_eq!(result[2], 5);
389/// ```
390#[allow(dead_code)]
391pub fn mask_select<T>(
392    array: ArrayView<T, Ix2>,
393    mask: ArrayView<bool, Ix2>,
394) -> Result<Array<T, Ix1>, &'static str>
395where
396    T: Clone + Default,
397{
398    // Check that the mask has the same shape as the array
399    if array.shape() != mask.shape() {
400        return Err("Mask shape must match array shape");
401    }
402
403    // Count the number of true values in the mask
404    let true_count = mask.iter().filter(|&&x| x).count();
405
406    // Create the result array
407    let mut result = Array::<T, Ix1>::default(true_count);
408
409    // Fill the result array with elements where the mask is true
410    let mut idx = 0;
411    for (val, &m) in array.iter().zip(mask.iter()) {
412        if m {
413            result[idx] = val.clone();
414            idx += 1;
415        }
416    }
417
418    Ok(result)
419}
420
421/// Index a 2D array with a list of index arrays (fancy indexing)
422///
423/// # Arguments
424///
425/// * `array` - The input 2D array
426/// * `row_indices` - Array of row indices
427/// * `col_indices` - Array of column indices
428///
429/// # Returns
430///
431/// A 1D array containing the elements at the specified indices
432///
433/// # Examples
434///
435/// ```
436/// use ndarray::array;
437/// use scirs2_core::ndarray_ext::fancy_index_2d;
438///
439/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
440/// let row_indices = array![0, 2];
441/// let col_indices = array![0, 1];
442/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
443/// assert_eq!(result.shape(), &[2]);
444/// assert_eq!(result[0], 1);
445/// assert_eq!(result[1], 8);
446/// ```
447#[allow(dead_code)]
448pub fn fancy_index_2d<T>(
449    array: ArrayView<T, Ix2>,
450    row_indices: ArrayView<usize, Ix1>,
451    col_indices: ArrayView<usize, Ix1>,
452) -> Result<Array<T, Ix1>, &'static str>
453where
454    T: Clone + Default,
455{
456    // Check that all index arrays have the same length
457    let result_size = row_indices.len();
458    if col_indices.len() != result_size {
459        return Err("Row and column index arrays must have the same length");
460    }
461
462    let (rows, cols) = (array.shape()[0], array.shape()[1]);
463
464    // Check that _indices are within bounds
465    for &idx in row_indices.iter() {
466        if idx >= rows {
467            return Err("Row index out of bounds");
468        }
469    }
470
471    for &idx in col_indices.iter() {
472        if idx >= cols {
473            return Err("Column index out of bounds");
474        }
475    }
476
477    // Create the result array
478    let mut result = Array::<T, Ix1>::default(result_size);
479
480    // Fill the result array
481    for i in 0..result_size {
482        let row = row_indices[i];
483        let col = col_indices[i];
484        result[i] = array[[row, col]].clone();
485    }
486
487    Ok(result)
488}
489
490/// Select elements from an array where a condition is true
491///
492/// # Arguments
493///
494/// * `array` - The input array
495/// * `condition` - A function that takes a reference to an element and returns a bool
496///
497/// # Returns
498///
499/// A 1D array containing the elements where the condition is true
500///
501/// # Examples
502///
503/// ```
504/// use ndarray::array;
505/// use scirs2_core::ndarray_ext::where_condition;
506///
507/// let a = array![[1, 2, 3], [4, 5, 6]];
508/// let result = where_condition(a.view(), |&x| x > 3).unwrap();
509/// assert_eq!(result.shape(), &[3]);
510/// assert_eq!(result[0], 4);
511/// assert_eq!(result[1], 5);
512/// assert_eq!(result[2], 6);
513/// ```
514#[allow(dead_code)]
515pub fn where_condition<T, F>(
516    array: ArrayView<T, Ix2>,
517    condition: F,
518) -> Result<Array<T, Ix1>, &'static str>
519where
520    T: Clone + Default,
521    F: Fn(&T) -> bool,
522{
523    // Build a boolean mask array based on the condition
524    let mask = array.map(condition);
525
526    // Use the mask_select function to select elements
527    mask_select(array, mask.view())
528}
529
/// Check whether two shapes are compatible under broadcasting rules
///
/// The shapes are right-aligned and compared from the last dimension backwards.
/// Each aligned pair must be equal or contain a 1. Leading dimensions present in only
/// one shape are implicitly 1 on the other side, so they can never cause a mismatch.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// `true` if the shapes are broadcast compatible, `false` otherwise
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::is_broadcast_compatible;
///
/// assert!(is_broadcast_compatible(&[2, 3], &[3]));
/// // Dimension 0 differs (5 vs 3) and neither size is 1
/// assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
/// assert!(!is_broadcast_compatible(&[2, 3], &[4]));
/// ```
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    // Zipping the reversed shapes visits exactly the overlapping trailing dimensions;
    // the padded leading 1s of the shorter shape are always compatible.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
580
/// Calculate the broadcasted shape from two input shapes
///
/// The shapes are right-aligned and missing leading dimensions are treated as 1.
/// Each aligned pair of dimensions must either be equal or contain a 1. A dimension
/// of size 1 stretches to the other size, including to 0, so broadcasting `[0]` with
/// `[1]` yields `[0]`.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// The broadcasted shape as a `Vec<usize>`, or `None` if the shapes are incompatible
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::broadcastshape;
///
/// assert_eq!(broadcastshape(&[2, 3], &[3]), Some(vec![2, 3]));
/// assert_eq!(broadcastshape(&[4, 1], &[1, 5]), Some(vec![4, 5]));
/// assert_eq!(broadcastshape(&[0, 3], &[1, 3]), Some(vec![0, 3]));
/// assert_eq!(broadcastshape(&[2, 3], &[4]), None);
/// ```
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let ndim = shape1.len().max(shape2.len());

    // Right-align `shape` within `ndim` dimensions, padding on the left with 1s.
    let dim_at = |shape: &[usize], i: usize| -> usize {
        let offset = ndim - shape.len();
        if i < offset {
            1
        } else {
            shape[i - offset]
        }
    };

    // Collecting `Option<usize>` items yields `None` as soon as any dimension pair
    // is incompatible.
    (0..ndim)
        .map(|i| match (dim_at(shape1, i), dim_at(shape2, i)) {
            (d1, d2) if d1 == d2 => Some(d1),
            // A size-1 dimension takes the other size, even when that size is 0;
            // `max` would wrongly give 1 for the pair (0, 1).
            (1, d2) => Some(d2),
            (d1, 1) => Some(d1),
            _ => None,
        })
        .collect()
}
622
623/// Broadcast a 1D array to a 2D shape by repeating it along the specified axis
624///
625/// # Arguments
626///
627/// * `array` - The input 1D array
628/// * `repeats` - Number of times to repeat the array
629/// * `axis` - Axis along which to repeat (0 for rows, 1 for columns)
630///
631/// # Returns
632///
633/// A 2D array with the input array repeated along the specified axis
634///
635/// # Examples
636///
637/// ```
638/// use ndarray::array;
639/// use scirs2_core::ndarray_ext::broadcast_1d_to_2d;
640///
641/// let a = array![1, 2, 3];
642/// let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
643/// assert_eq!(b.shape(), &[2, 3]);
644/// assert_eq!(b[[0, 0]], 1);
645/// assert_eq!(b[[1, 0]], 1);
646/// ```
647#[allow(dead_code)]
648pub fn broadcast_1d_to_2d<T>(
649    array: ArrayView<T, Ix1>,
650    repeats: usize,
651    axis: usize,
652) -> Result<Array<T, Ix2>, &'static str>
653where
654    T: Clone + Default,
655{
656    let len = array.len();
657
658    // Create the result array with the appropriate shape
659    let (rows, cols) = match axis {
660        0 => (repeats, len), // Broadcast along rows
661        1 => (len, repeats), // Broadcast along columns
662        _ => return Err("Axis must be 0 or 1"),
663    };
664
665    let mut result = Array::<T, Ix2>::default((rows, cols));
666
667    // Fill the result array
668    match axis {
669        0 => {
670            // Broadcast along rows
671            for i in 0..repeats {
672                for j in 0..len {
673                    result[[i, j]] = array[j].clone();
674                }
675            }
676        }
677        1 => {
678            // Broadcast along columns
679            for i in 0..len {
680                for j in 0..repeats {
681                    result[[i, j]] = array[i].clone();
682                }
683            }
684        }
685        _ => unreachable!(),
686    }
687
688    Ok(result)
689}
690
691/// Apply an element-wise binary operation to two arrays with broadcasting
692///
693/// # Arguments
694///
695/// * `a` - First array (2D)
696/// * `b` - Second array (can be 1D or 2D)
697/// * `op` - Binary operation to apply to each pair of elements
698///
699/// # Returns
700///
701/// A 2D array containing the result of the operation applied element-wise
702///
703/// # Examples
704///
705/// ```
706/// use ndarray::array;
707/// use scirs2_core::ndarray_ext::broadcast_apply;
708///
709/// let a = array![[1, 2, 3], [4, 5, 6]];
710/// let b = array![10, 20, 30];
711/// let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
712/// assert_eq!(result.shape(), &[2, 3]);
713/// assert_eq!(result[[0, 0]], 11);
714/// assert_eq!(result[[1, 2]], 36);
715/// ```
716#[allow(dead_code)]
717pub fn broadcast_apply<T, R, F>(
718    a: ArrayView<T, Ix2>,
719    b: ArrayView<T, Ix1>,
720    op: F,
721) -> Result<Array<R, Ix2>, &'static str>
722where
723    T: Clone + Default,
724    R: Clone + Default,
725    F: Fn(&T, &T) -> R,
726{
727    let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
728    let b_len = b.len();
729
730    // Check that the arrays are broadcast compatible
731    if a_cols != b_len {
732        return Err("Arrays are not broadcast compatible");
733    }
734
735    // Create the result array
736    let mut result = Array::<R, Ix2>::default((a_rows, a_cols));
737
738    // Apply the operation element-wise
739    for i in 0..a_rows {
740        for j in 0..a_cols {
741            result[[i, j]] = op(&a[[i, j]], &b[j]);
742        }
743    }
744
745    Ok(result)
746}
747
#[cfg(test)]
mod tests {
    // Unit tests for the 2D helpers in this module. Each test builds small literal
    // arrays with `array!` and checks the output shape plus selected element values.
    use super::*;
    use ndarray::array;

    #[test]
    fn test_reshape_2d() {
        // 2x2 -> 4x1: elements must come out in row-major order.
        let a = array![[1, 2], [3, 4]];
        let b = reshape_2d(a.view(), (4, 1)).unwrap();
        assert_eq!(b.shape(), &[4, 1]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 0]], 3);
        assert_eq!(b[[3, 0]], 4);

        // Test invalid shape (3 elements requested, 4 available)
        let result = reshape_2d(a.view(), (3, 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        // Stack vertically (along axis 0)
        let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
        assert_eq!(c.shape(), &[4, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[1, 0]], 3);
        assert_eq!(c[[2, 0]], 5);
        assert_eq!(c[[3, 0]], 7);

        // Stack horizontally (along axis 1)
        let d = stack_2d(&[a.view(), b.view()], 1).unwrap();
        assert_eq!(d.shape(), &[2, 4]);
        assert_eq!(d[[0, 0]], 1);
        assert_eq!(d[[0, 1]], 2);
        assert_eq!(d[[0, 2]], 5);
        assert_eq!(d[[0, 3]], 6);
    }

    #[test]
    fn test_transpose_2d() {
        // 2x3 -> 3x2; b[[j, i]] must equal a[[i, j]].
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = transpose_2d(a.view());
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 4);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 1]], 6);
    }

    #[test]
    fn test_split_2d() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Split along columns at index 2
        let result = split_2d(a.view(), &[2], 1).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 2]);
        assert_eq!(result[0][[0, 0]], 1);
        assert_eq!(result[0][[0, 1]], 2);
        assert_eq!(result[0][[1, 0]], 5);
        assert_eq!(result[0][[1, 1]], 6);
        assert_eq!(result[1].shape(), &[2, 2]);
        assert_eq!(result[1][[0, 0]], 3);
        assert_eq!(result[1][[0, 1]], 4);
        assert_eq!(result[1][[1, 0]], 7);
        assert_eq!(result[1][[1, 1]], 8);

        // Split along rows at index 1
        let result = split_2d(a.view(), &[1], 0).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[1, 4]);
        assert_eq!(result[1].shape(), &[1, 4]);
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let indices = array![0, 2];

        // Take along columns: keeps columns 0 and 2 of every row
        let result = take_2d(a.view(), indices.view(), 1).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 3);
        assert_eq!(result[[1, 0]], 4);
        assert_eq!(result[[1, 1]], 6);
    }

    #[test]
    fn test_mask_select() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        // Selected values come out in row-major order.
        let result = mask_select(a.view(), mask.view()).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 3);
        assert_eq!(result[2], 5);
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        // Picks a[[0, 0]] and a[[2, 1]].
        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(result.shape(), &[2]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 8);
    }

    #[test]
    fn test_where_condition() {
        // Only the second row satisfies x > 3.
        let a = array![[1, 2, 3], [4, 5, 6]];
        let result = where_condition(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 4);
        assert_eq!(result[1], 5);
        assert_eq!(result[2], 6);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let a = array![1, 2, 3];

        // Broadcast along rows (axis 0): every row is [1, 2, 3]
        let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 2);
        assert_eq!(b[[1, 0]], 1);
        assert_eq!(b[[1, 2]], 3);

        // Broadcast along columns (axis 1): every column is [1, 2, 3]
        let c = broadcast_1d_to_2d(a.view(), 2, 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[0, 1]], 1);
        assert_eq!(c[[1, 0]], 2);
        assert_eq!(c[[2, 1]], 3);
    }

    #[test]
    fn test_broadcast_apply() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = array![10, 20, 30];

        // Addition: b is added to each row of a.
        let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 11);
        assert_eq!(result[[0, 1]], 22);
        assert_eq!(result[[0, 2]], 33);
        assert_eq!(result[[1, 0]], 14);
        assert_eq!(result[[1, 1]], 25);
        assert_eq!(result[[1, 2]], 36);

        // Multiplication: each row of a is scaled element-wise by b.
        let result = broadcast_apply(a.view(), b.view(), |x, y| x * y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 10);
        assert_eq!(result[[0, 1]], 40);
        assert_eq!(result[[0, 2]], 90);
        assert_eq!(result[[1, 0]], 40);
        assert_eq!(result[[1, 1]], 100);
        assert_eq!(result[[1, 2]], 180);
    }
}