scirs2_core/ndarray_ext/
mod.rs

//! Extended ndarray operations for scientific computing
//!
//! This module provides additional functionality for ndarray to support
//! the advanced array operations necessary for a complete SciPy port.
//! It implements core `NumPy`-like features that are not available in the
//! base ndarray crate.
7
8/// Re-export essential ndarray types for convenience
9pub use ndarray::{
10    s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, DataMut, Dim, Dimension, Ix1, Ix2,
11    Ix3, Ix4, Ix5, Ix6, IxDyn, OwnedRepr, RemoveAxis, ScalarOperand, ShapeBuilder, ShapeError,
12    SliceInfo, ViewRepr, Zip,
13};
14
15// Re-export type aliases for common array sizes
16pub use ndarray::{ArcArray1, ArcArray2};
17pub use ndarray::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD};
18pub use ndarray::{
19    ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD,
20};
21pub use ndarray::{
22    ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5,
23    ArrayViewMut6, ArrayViewMutD,
24};
25
26/// Advanced indexing operations (`NumPy`-like boolean masking, fancy indexing, etc.)
27pub mod indexing;
28
29/// Statistical functions for ndarray arrays (mean, median, variance, correlation, etc.)
30pub mod stats;
31
32/// Matrix operations (eye, diag, kron, etc.)
33pub mod matrix;
34
35/// Array manipulation operations (flip, roll, tile, repeat, etc.)
36pub mod manipulation;
37
38/// Reshape a 2D array to a new shape without copying data when possible
39///
40/// # Arguments
41///
42/// * `array` - The input array
43/// * `shape` - The new shape (rows, cols), which must be compatible with the original shape
44///
45/// # Returns
46///
47/// A reshaped array view of the input array if possible, or a new array otherwise
48///
49/// # Examples
50///
51/// ```
52/// use ndarray::array;
53/// use scirs2_core::ndarray_ext::reshape_2d;
54///
55/// let a = array![[1, 2], [3, 4]];
56/// let b = reshape_2d(a.view(), (4, 1)).unwrap();
57/// assert_eq!(b.shape(), &[4, 1]);
58/// assert_eq!(b[[0, 0]], 1);
59/// assert_eq!(b[[3, 0]], 4);
60/// ```
61#[allow(dead_code)]
62pub fn reshape_2d<T>(
63    array: ArrayView<T, Ix2>,
64    shape: (usize, usize),
65) -> Result<Array<T, Ix2>, &'static str>
66where
67    T: Clone + Default,
68{
69    let (rows, cols) = shape;
70    let total_elements = rows * cols;
71
72    // Check if the new shape is compatible with the original shape
73    if total_elements != array.len() {
74        return Err("New shape dimensions must match the total number of elements");
75    }
76
77    // Create a new array with the specified shape
78    let mut result = Array::<T, Ix2>::default(shape);
79
80    // Fill the result array with elements from the input array
81    let flat_iter = array.iter();
82    for (i, val) in flat_iter.enumerate() {
83        let r = i / cols;
84        let c = i % cols;
85        result[[r, c]] = val.clone();
86    }
87
88    Ok(result)
89}
90
91/// Stack 2D arrays along a given axis
92///
93/// # Arguments
94///
95/// * `arrays` - A slice of 2D arrays to stack
96/// * `axis` - The axis along which to stack (0 for rows, 1 for columns)
97///
98/// # Returns
99///
100/// A new array containing the stacked arrays
101///
102/// # Examples
103///
104/// ```
105/// use ndarray::{array, Axis};
106/// use scirs2_core::ndarray_ext::stack_2d;
107///
108/// let a = array![[1, 2], [3, 4]];
109/// let b = array![[5, 6], [7, 8]];
110/// let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
111/// assert_eq!(c.shape(), &[4, 2]);
112/// ```
113#[allow(dead_code)]
114pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
115where
116    T: Clone + Default,
117{
118    if arrays.is_empty() {
119        return Err("No _arrays provided for stacking");
120    }
121
122    // Validate that all _arrays have the same shape
123    let firstshape = arrays[0].shape();
124    for array in arrays.iter().skip(1) {
125        if array.shape() != firstshape {
126            return Err("All _arrays must have the same shape for stacking");
127        }
128    }
129
130    let (rows, cols) = (firstshape[0], firstshape[1]);
131
132    // Calculate the new shape
133    let (new_rows, new_cols) = match axis {
134        0 => (rows * arrays.len(), cols), // Stack vertically
135        1 => (rows, cols * arrays.len()), // Stack horizontally
136        _ => return Err("Axis must be 0 or 1 for 2D _arrays"),
137    };
138
139    // Create a new array to hold the stacked result
140    let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
141
142    // Copy data from the input _arrays to the result
143    match axis {
144        0 => {
145            // Stack vertically (along rows)
146            for (array_idx, array) in arrays.iter().enumerate() {
147                let start_row = array_idx * rows;
148                for r in 0..rows {
149                    for c in 0..cols {
150                        result[[start_row + r, c]] = array[[r, c]].clone();
151                    }
152                }
153            }
154        }
155        1 => {
156            // Stack horizontally (along columns)
157            for (array_idx, array) in arrays.iter().enumerate() {
158                let start_col = array_idx * cols;
159                for r in 0..rows {
160                    for c in 0..cols {
161                        result[[r, start_col + c]] = array[[r, c]].clone();
162                    }
163                }
164            }
165        }
166        _ => unreachable!(),
167    }
168
169    Ok(result)
170}
171
172/// Swap axes (transpose) of a 2D array
173///
174/// # Arguments
175///
176/// * `array` - The input 2D array
177///
178/// # Returns
179///
180/// A view of the input array with the axes swapped
181///
182/// # Examples
183///
184/// ```
185/// use ndarray::array;
186/// use scirs2_core::ndarray_ext::transpose_2d;
187///
188/// let a = array![[1, 2, 3], [4, 5, 6]];
189/// let b = transpose_2d(a.view());
190/// assert_eq!(b.shape(), &[3, 2]);
191/// assert_eq!(b[[0, 0]], 1);
192/// assert_eq!(b[[0, 1]], 4);
193/// ```
194#[allow(dead_code)]
195pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
196where
197    T: Clone,
198{
199    array.t().to_owned()
200}
201
202/// Split a 2D array into multiple sub-arrays along a given axis
203///
204/// # Arguments
205///
206/// * `array` - The input 2D array to split
207/// * `indices` - Indices where the array should be split
208/// * `axis` - The axis along which to split (0 for rows, 1 for columns)
209///
210/// # Returns
211///
212/// A vector of arrays resulting from the split
213///
214/// # Examples
215///
216/// ```
217/// use ndarray::array;
218/// use scirs2_core::ndarray_ext::split_2d;
219///
220/// let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
221/// let result = split_2d(a.view(), &[2], 1).unwrap();
222/// assert_eq!(result.len(), 2);
223/// assert_eq!(result[0].shape(), &[2, 2]);
224/// assert_eq!(result[1].shape(), &[2, 2]);
225/// ```
226#[allow(dead_code)]
227pub fn split_2d<T>(
228    array: ArrayView<T, Ix2>,
229    indices: &[usize],
230    axis: usize,
231) -> Result<Vec<Array<T, Ix2>>, &'static str>
232where
233    T: Clone + Default,
234{
235    if indices.is_empty() {
236        return Ok(vec![array.to_owned()]);
237    }
238
239    let (rows, cols) = (array.shape()[0], array.shape()[1]);
240    let axis_len = if axis == 0 { rows } else { cols };
241
242    // Validate indices
243    for &idx in indices {
244        if idx >= axis_len {
245            return Err("Split index out of bounds");
246        }
247    }
248
249    // Sort indices to ensure they're in ascending order
250    let mut sorted_indices = indices.to_vec();
251    sorted_indices.sort_unstable();
252
253    // Calculate the sub-array boundaries
254    let mut starts = vec![0];
255    starts.extend_from_slice(&sorted_indices);
256
257    let mut ends = sorted_indices.clone();
258    ends.push(axis_len);
259
260    // Create the split sub-arrays
261    let mut result = Vec::with_capacity(starts.len());
262
263    match axis {
264        0 => {
265            // Split along rows
266            for (&start, &end) in starts.iter().zip(ends.iter()) {
267                let sub_rows = end - start;
268                let mut sub_array = Array::<T, Ix2>::default((sub_rows, cols));
269
270                for r in 0..sub_rows {
271                    for c in 0..cols {
272                        sub_array[[r, c]] = array[[start + r, c]].clone();
273                    }
274                }
275
276                result.push(sub_array);
277            }
278        }
279        1 => {
280            // Split along columns
281            for (&start, &end) in starts.iter().zip(ends.iter()) {
282                let sub_cols = end - start;
283                let mut sub_array = Array::<T, Ix2>::default((rows, sub_cols));
284
285                for r in 0..rows {
286                    for c in 0..sub_cols {
287                        sub_array[[r, c]] = array[[r, start + c]].clone();
288                    }
289                }
290
291                result.push(sub_array);
292            }
293        }
294        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
295    }
296
297    Ok(result)
298}
299
300/// Take elements from a 2D array along a given axis using indices from another array
301///
302/// # Arguments
303///
304/// * `array` - The input 2D array
305/// * `indices` - Array of indices to take
306/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
307///
308/// # Returns
309///
310/// An array of values at the specified indices along the given axis
311///
312/// # Examples
313///
314/// ```
315/// use ndarray::array;
316/// use scirs2_core::ndarray_ext::take_2d;
317///
318/// let a = array![[1, 2, 3], [4, 5, 6]];
319/// let indices = array![0, 2];
320/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
321/// assert_eq!(result.shape(), &[2, 2]);
322/// assert_eq!(result[[0, 0]], 1);
323/// assert_eq!(result[[0, 1]], 3);
324/// ```
325#[allow(dead_code)]
326pub fn take_2d<T>(
327    array: ArrayView<T, Ix2>,
328    indices: ArrayView<usize, Ix1>,
329    axis: usize,
330) -> Result<Array<T, Ix2>, &'static str>
331where
332    T: Clone + Default,
333{
334    let (rows, cols) = (array.shape()[0], array.shape()[1]);
335    let axis_len = if axis == 0 { rows } else { cols };
336
337    // Check that indices are within bounds
338    for &idx in indices.iter() {
339        if idx >= axis_len {
340            return Err("Index out of bounds");
341        }
342    }
343
344    // Create the result array with the appropriate shape
345    let (result_rows, result_cols) = match axis {
346        0 => (indices.len(), cols),
347        1 => (rows, indices.len()),
348        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
349    };
350
351    let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
352
353    // Fill the result array
354    match axis {
355        0 => {
356            // Take along rows
357            for (i, &idx) in indices.iter().enumerate() {
358                for j in 0..cols {
359                    result[[i, j]] = array[[idx, j]].clone();
360                }
361            }
362        }
363        1 => {
364            // Take along columns
365            for i in 0..rows {
366                for (j, &idx) in indices.iter().enumerate() {
367                    result[[i, j]] = array[[i, idx]].clone();
368                }
369            }
370        }
371        _ => unreachable!(),
372    }
373
374    Ok(result)
375}
376
377/// Filter an array using a boolean mask
378///
379/// # Arguments
380///
381/// * `array` - The input array
382/// * `mask` - Boolean mask of the same shape as the array
383///
384/// # Returns
385///
386/// A 1D array containing the elements where the mask is true
387///
388/// # Examples
389///
390/// ```
391/// use ndarray::array;
392/// use scirs2_core::ndarray_ext::mask_select;
393///
394/// let a = array![[1, 2, 3], [4, 5, 6]];
395/// let mask = array![[true, false, true], [false, true, false]];
396/// let result = mask_select(a.view(), mask.view()).unwrap();
397/// assert_eq!(result.shape(), &[3]);
398/// assert_eq!(result[0], 1);
399/// assert_eq!(result[1], 3);
400/// assert_eq!(result[2], 5);
401/// ```
402#[allow(dead_code)]
403pub fn mask_select<T>(
404    array: ArrayView<T, Ix2>,
405    mask: ArrayView<bool, Ix2>,
406) -> Result<Array<T, Ix1>, &'static str>
407where
408    T: Clone + Default,
409{
410    // Check that the mask has the same shape as the array
411    if array.shape() != mask.shape() {
412        return Err("Mask shape must match array shape");
413    }
414
415    // Count the number of true values in the mask
416    let true_count = mask.iter().filter(|&&x| x).count();
417
418    // Create the result array
419    let mut result = Array::<T, Ix1>::default(true_count);
420
421    // Fill the result array with elements where the mask is true
422    let mut idx = 0;
423    for (val, &m) in array.iter().zip(mask.iter()) {
424        if m {
425            result[idx] = val.clone();
426            idx += 1;
427        }
428    }
429
430    Ok(result)
431}
432
433/// Index a 2D array with a list of index arrays (fancy indexing)
434///
435/// # Arguments
436///
437/// * `array` - The input 2D array
438/// * `row_indices` - Array of row indices
439/// * `col_indices` - Array of column indices
440///
441/// # Returns
442///
443/// A 1D array containing the elements at the specified indices
444///
445/// # Examples
446///
447/// ```
448/// use ndarray::array;
449/// use scirs2_core::ndarray_ext::fancy_index_2d;
450///
451/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
452/// let row_indices = array![0, 2];
453/// let col_indices = array![0, 1];
454/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
455/// assert_eq!(result.shape(), &[2]);
456/// assert_eq!(result[0], 1);
457/// assert_eq!(result[1], 8);
458/// ```
459#[allow(dead_code)]
460pub fn fancy_index_2d<T>(
461    array: ArrayView<T, Ix2>,
462    row_indices: ArrayView<usize, Ix1>,
463    col_indices: ArrayView<usize, Ix1>,
464) -> Result<Array<T, Ix1>, &'static str>
465where
466    T: Clone + Default,
467{
468    // Check that all index arrays have the same length
469    let result_size = row_indices.len();
470    if col_indices.len() != result_size {
471        return Err("Row and column index arrays must have the same length");
472    }
473
474    let (rows, cols) = (array.shape()[0], array.shape()[1]);
475
476    // Check that _indices are within bounds
477    for &idx in row_indices.iter() {
478        if idx >= rows {
479            return Err("Row index out of bounds");
480        }
481    }
482
483    for &idx in col_indices.iter() {
484        if idx >= cols {
485            return Err("Column index out of bounds");
486        }
487    }
488
489    // Create the result array
490    let mut result = Array::<T, Ix1>::default(result_size);
491
492    // Fill the result array
493    for i in 0..result_size {
494        let row = row_indices[i];
495        let col = col_indices[i];
496        result[i] = array[[row, col]].clone();
497    }
498
499    Ok(result)
500}
501
502/// Select elements from an array where a condition is true
503///
504/// # Arguments
505///
506/// * `array` - The input array
507/// * `condition` - A function that takes a reference to an element and returns a bool
508///
509/// # Returns
510///
511/// A 1D array containing the elements where the condition is true
512///
513/// # Examples
514///
515/// ```
516/// use ndarray::array;
517/// use scirs2_core::ndarray_ext::where_condition;
518///
519/// let a = array![[1, 2, 3], [4, 5, 6]];
520/// let result = where_condition(a.view(), |&x| x > 3).unwrap();
521/// assert_eq!(result.shape(), &[3]);
522/// assert_eq!(result[0], 4);
523/// assert_eq!(result[1], 5);
524/// assert_eq!(result[2], 6);
525/// ```
526#[allow(dead_code)]
527pub fn where_condition<T, F>(
528    array: ArrayView<T, Ix2>,
529    condition: F,
530) -> Result<Array<T, Ix1>, &'static str>
531where
532    T: Clone + Default,
533    F: Fn(&T) -> bool,
534{
535    // Build a boolean mask array based on the condition
536    let mask = array.map(condition);
537
538    // Use the mask_select function to select elements
539    mask_select(array, mask.view())
540}
541
/// Check if two shapes are broadcast compatible
///
/// Shapes are compared from their trailing dimension backwards. A pair of
/// dimensions is compatible when the sizes are equal or either one is 1;
/// dimensions missing from the shorter shape count as 1.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// `true` if the shapes are broadcast compatible, `false` otherwise
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::is_broadcast_compatible;
///
/// assert!(is_broadcast_compatible(&[2, 3], &[3]));
/// // Dimension 0 differs (5 vs 3) and neither side is 1
/// assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
/// assert!(!is_broadcast_compatible(&[2, 3], &[4]));
/// ```
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    // Zipping the reversed shapes stops at the shorter one. The unmatched
    // leading dimensions pair with an implicit 1, which is always compatible,
    // so they need no check.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
592
/// Calculate the broadcasted shape from two input shapes
///
/// Shapes are aligned on their trailing dimension; dimensions missing from
/// the shorter shape count as 1. For each aligned pair the sizes must be
/// equal or one of them must be 1, and the result takes the size that is not
/// the broadcast 1. A zero-length dimension paired with 1 therefore stays 0.
/// The result is `Some` exactly when `is_broadcast_compatible` accepts the
/// same two shapes.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// The broadcasted shape as a `Vec<usize>`, or `None` if the shapes are incompatible
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::broadcastshape;
///
/// assert_eq!(broadcastshape(&[2, 3], &[3]), Some(vec![2, 3]));
/// assert_eq!(broadcastshape(&[0], &[1]), Some(vec![0]));
/// assert_eq!(broadcastshape(&[2, 3], &[4]), None);
/// ```
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    // Start from the longer shape: its unmatched leading dimensions pair with
    // an implicit 1 and carry over unchanged.
    let (longer, shorter) = if shape1.len() >= shape2.len() {
        (shape1, shape2)
    } else {
        (shape2, shape1)
    };
    let offset = longer.len() - shorter.len();
    let mut result = longer.to_vec();

    for (out, &dim) in result[offset..].iter_mut().zip(shorter) {
        if *out == dim || dim == 1 {
            // Equal sizes, or the other side is the one being broadcast.
            continue;
        }
        if *out != 1 {
            return None;
        }
        // This side is the broadcast 1, so the other size wins. Taking the
        // maximum instead would turn a zero-length dimension into 1.
        *out = dim;
    }

    Some(result)
}
634
635/// Broadcast a 1D array to a 2D shape by repeating it along the specified axis
636///
637/// # Arguments
638///
639/// * `array` - The input 1D array
640/// * `repeats` - Number of times to repeat the array
641/// * `axis` - Axis along which to repeat (0 for rows, 1 for columns)
642///
643/// # Returns
644///
645/// A 2D array with the input array repeated along the specified axis
646///
647/// # Examples
648///
649/// ```
650/// use ndarray::array;
651/// use scirs2_core::ndarray_ext::broadcast_1d_to_2d;
652///
653/// let a = array![1, 2, 3];
654/// let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
655/// assert_eq!(b.shape(), &[2, 3]);
656/// assert_eq!(b[[0, 0]], 1);
657/// assert_eq!(b[[1, 0]], 1);
658/// ```
659#[allow(dead_code)]
660pub fn broadcast_1d_to_2d<T>(
661    array: ArrayView<T, Ix1>,
662    repeats: usize,
663    axis: usize,
664) -> Result<Array<T, Ix2>, &'static str>
665where
666    T: Clone + Default,
667{
668    let len = array.len();
669
670    // Create the result array with the appropriate shape
671    let (rows, cols) = match axis {
672        0 => (repeats, len), // Broadcast along rows
673        1 => (len, repeats), // Broadcast along columns
674        _ => return Err("Axis must be 0 or 1"),
675    };
676
677    let mut result = Array::<T, Ix2>::default((rows, cols));
678
679    // Fill the result array
680    match axis {
681        0 => {
682            // Broadcast along rows
683            for i in 0..repeats {
684                for j in 0..len {
685                    result[[i, j]] = array[j].clone();
686                }
687            }
688        }
689        1 => {
690            // Broadcast along columns
691            for i in 0..len {
692                for j in 0..repeats {
693                    result[[i, j]] = array[i].clone();
694                }
695            }
696        }
697        _ => unreachable!(),
698    }
699
700    Ok(result)
701}
702
703/// Apply an element-wise binary operation to two arrays with broadcasting
704///
705/// # Arguments
706///
707/// * `a` - First array (2D)
708/// * `b` - Second array (can be 1D or 2D)
709/// * `op` - Binary operation to apply to each pair of elements
710///
711/// # Returns
712///
713/// A 2D array containing the result of the operation applied element-wise
714///
715/// # Examples
716///
717/// ```
718/// use ndarray::array;
719/// use scirs2_core::ndarray_ext::broadcast_apply;
720///
721/// let a = array![[1, 2, 3], [4, 5, 6]];
722/// let b = array![10, 20, 30];
723/// let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
724/// assert_eq!(result.shape(), &[2, 3]);
725/// assert_eq!(result[[0, 0]], 11);
726/// assert_eq!(result[[1, 2]], 36);
727/// ```
728#[allow(dead_code)]
729pub fn broadcast_apply<T, R, F>(
730    a: ArrayView<T, Ix2>,
731    b: ArrayView<T, Ix1>,
732    op: F,
733) -> Result<Array<R, Ix2>, &'static str>
734where
735    T: Clone + Default,
736    R: Clone + Default,
737    F: Fn(&T, &T) -> R,
738{
739    let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
740    let b_len = b.len();
741
742    // Check that the arrays are broadcast compatible
743    if a_cols != b_len {
744        return Err("Arrays are not broadcast compatible");
745    }
746
747    // Create the result array
748    let mut result = Array::<R, Ix2>::default((a_rows, a_cols));
749
750    // Apply the operation element-wise
751    for i in 0..a_rows {
752        for j in 0..a_cols {
753            result[[i, j]] = op(&a[[i, j]], &b[j]);
754        }
755    }
756
757    Ok(result)
758}
759
// Unit tests for the helpers in this module. Results are compared as whole
// arrays so that every element is checked, and each fallible helper has at
// least one error-path assertion.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_reshape_2d() {
        let a = array![[1, 2], [3, 4]];

        let b = reshape_2d(a.view(), (4, 1)).unwrap();
        assert_eq!(b.shape(), &[4, 1]);
        assert_eq!(b, array![[1], [2], [3], [4]]);

        let c = reshape_2d(a.view(), (1, 4)).unwrap();
        assert_eq!(c, array![[1, 2, 3, 4]]);

        // Element count mismatch is rejected
        assert!(reshape_2d(a.view(), (3, 1)).is_err());
    }

    #[test]
    fn test_stack_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        // Stack vertically (along axis 0)
        let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
        assert_eq!(c.shape(), &[4, 2]);
        assert_eq!(c, array![[1, 2], [3, 4], [5, 6], [7, 8]]);

        // Stack horizontally (along axis 1)
        let d = stack_2d(&[a.view(), b.view()], 1).unwrap();
        assert_eq!(d.shape(), &[2, 4]);
        assert_eq!(d, array![[1, 2, 5, 6], [3, 4, 7, 8]]);
    }

    #[test]
    fn test_stack_2d_errors() {
        let a = array![[1, 2], [3, 4]];
        let wide = array![[1, 2, 3]];

        // No inputs
        assert!(stack_2d::<i32>(&[], 0).is_err());
        // Mismatched shapes
        assert!(stack_2d(&[a.view(), wide.view()], 0).is_err());
        // Invalid axis
        assert!(stack_2d(&[a.view(), a.view()], 2).is_err());
    }

    #[test]
    fn test_transpose_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = transpose_2d(a.view());
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b, array![[1, 4], [2, 5], [3, 6]]);
    }

    #[test]
    fn test_split_2d() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Split along columns at index 2
        let result = split_2d(a.view(), &[2], 1).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], array![[1, 2], [5, 6]]);
        assert_eq!(result[1], array![[3, 4], [7, 8]]);

        // Split along rows at index 1
        let result = split_2d(a.view(), &[1], 0).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], array![[1, 2, 3, 4]]);
        assert_eq!(result[1], array![[5, 6, 7, 8]]);

        // Split points may be given out of order
        let result = split_2d(a.view(), &[3, 1], 1).unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], array![[1], [5]]);
        assert_eq!(result[1], array![[2, 3], [6, 7]]);
        assert_eq!(result[2], array![[4], [8]]);

        // No split points yields a single copy of the input
        let result = split_2d(a.view(), &[], 0).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], a);
    }

    #[test]
    fn test_split_2d_errors() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Index equal to the axis length is out of bounds
        assert!(split_2d(a.view(), &[4], 1).is_err());
        // Invalid axis
        assert!(split_2d(a.view(), &[1], 2).is_err());
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        // Take along columns
        let col_picks = array![0, 2];
        let result = take_2d(a.view(), col_picks.view(), 1).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result, array![[1, 3], [4, 6]]);

        // Take along rows, in reversed order
        let row_picks = array![1, 0];
        let result = take_2d(a.view(), row_picks.view(), 0).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result, array![[4, 5, 6], [1, 2, 3]]);
    }

    #[test]
    fn test_take_2d_errors() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        // Column index out of bounds
        let too_far = array![3];
        assert!(take_2d(a.view(), too_far.view(), 1).is_err());

        // Invalid axis
        let valid = array![0];
        assert!(take_2d(a.view(), valid.view(), 2).is_err());
    }

    #[test]
    fn test_mask_select() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let result = mask_select(a.view(), mask.view()).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result, array![1, 3, 5]);

        // Mask with a different shape is rejected
        let bad_mask = array![[true, false]];
        assert!(mask_select(a.view(), bad_mask.view()).is_err());
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(result.shape(), &[2]);
        assert_eq!(result, array![1, 8]);
    }

    #[test]
    fn test_fancy_index_2d_errors() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let pair = array![0, 2];
        let single = array![0];
        let too_far = array![0, 3];

        // Index arrays of different lengths
        assert!(fancy_index_2d(a.view(), pair.view(), single.view()).is_err());
        // Row index out of bounds
        assert!(fancy_index_2d(a.view(), too_far.view(), pair.view()).is_err());
        // Column index out of bounds
        assert!(fancy_index_2d(a.view(), pair.view(), too_far.view()).is_err());
    }

    #[test]
    fn test_where_condition() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let result = where_condition(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result, array![4, 5, 6]);

        // A predicate that matches nothing yields an empty array
        let none = where_condition(a.view(), |&x| x > 100).unwrap();
        assert_eq!(none.len(), 0);
    }

    #[test]
    fn test_is_broadcast_compatible() {
        assert!(is_broadcast_compatible(&[2, 3], &[3]));
        assert!(is_broadcast_compatible(&[8, 1, 6, 1], &[7, 1, 5]));
        assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
        assert!(!is_broadcast_compatible(&[2, 3], &[4]));
    }

    #[test]
    fn test_broadcastshape() {
        assert_eq!(broadcastshape(&[2, 3], &[3]), Some(vec![2, 3]));
        assert_eq!(broadcastshape(&[5, 1, 4], &[1, 3, 1]), Some(vec![5, 3, 4]));
        assert_eq!(broadcastshape(&[2, 3], &[4]), None);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let a = array![1, 2, 3];

        // Broadcast along rows (axis 0)
        let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert_eq!(b, array![[1, 2, 3], [1, 2, 3]]);

        // Broadcast along columns (axis 1)
        let c = broadcast_1d_to_2d(a.view(), 2, 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_eq!(c, array![[1, 1], [2, 2], [3, 3]]);

        // Invalid axis
        assert!(broadcast_1d_to_2d(a.view(), 2, 2).is_err());
    }

    #[test]
    fn test_broadcast_apply() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = array![10, 20, 30];

        let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result, array![[11, 22, 33], [14, 25, 36]]);

        let result = broadcast_apply(a.view(), b.view(), |x, y| x * y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result, array![[10, 40, 90], [40, 100, 180]]);

        // Length of `b` must equal the number of columns of `a`
        let short = array![10, 20];
        assert!(broadcast_apply(a.view(), short.view(), |x, y| x + y).is_err());
    }
}