scirs2_core/ndarray_ext/
mod.rs

1//! Extended ndarray operations for scientific computing
2//!
3//! This module provides additional functionality for ndarray to support
4//! the advanced array operations necessary for a complete SciPy port.
5//! It implements core `NumPy`-like features that are not available in the
6//! base ndarray crate.
7
8/// Re-export essential ndarray types and macros for convenience
9pub use ndarray::{
10    s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, DataMut, Dim, Dimension, Ix1, Ix2,
11    Ix3, Ix4, Ix5, Ix6, IxDyn, OwnedRepr, RemoveAxis, ScalarOperand, ShapeBuilder, ShapeError,
12    SliceInfo, ViewRepr, Zip,
13};
14
15// Re-export the array! macro for convenient array creation
16// This addresses the common issue where users expect array! to be available in scirs2-core
17// instead of requiring import from scirs2_autograd::ndarray
18pub use ndarray::{arr1, arr2, array};
19
20// Re-export type aliases for common array sizes
21pub use ndarray::{ArcArray1, ArcArray2};
22pub use ndarray::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD};
23pub use ndarray::{
24    ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD,
25};
26pub use ndarray::{
27    ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5,
28    ArrayViewMut6, ArrayViewMutD,
29};
30
31/// Advanced indexing operations (`NumPy`-like boolean masking, fancy indexing, etc.)
32pub mod indexing;
33
34/// Statistical functions for ndarray arrays (mean, median, variance, correlation, etc.)
35pub mod stats;
36
37/// Matrix operations (eye, diag, kron, etc.)
38pub mod matrix;
39
40/// Array manipulation operations (flip, roll, tile, repeat, etc.)
41pub mod manipulation;
42
43/// Random array generation with full ndarray-rand compatibility and scientific computing extensions
44#[cfg(feature = "random")]
45pub mod random;
46
47/// Reshape a 2D array to a new shape without copying data when possible
48///
49/// # Arguments
50///
51/// * `array` - The input array
52/// * `shape` - The new shape (rows, cols), which must be compatible with the original shape
53///
54/// # Returns
55///
56/// A reshaped array view of the input array if possible, or a new array otherwise
57///
58/// # Examples
59///
60/// ```
61/// use ndarray::array;
62/// use scirs2_core::ndarray_ext::reshape_2d;
63///
64/// let a = array![[1, 2], [3, 4]];
65/// let b = reshape_2d(a.view(), (4, 1)).unwrap();
66/// assert_eq!(b.shape(), &[4, 1]);
67/// assert_eq!(b[[0, 0]], 1);
68/// assert_eq!(b[[3, 0]], 4);
69/// ```
70#[allow(dead_code)]
71pub fn reshape_2d<T>(
72    array: ArrayView<T, Ix2>,
73    shape: (usize, usize),
74) -> Result<Array<T, Ix2>, &'static str>
75where
76    T: Clone + Default,
77{
78    let (rows, cols) = shape;
79    let total_elements = rows * cols;
80
81    // Check if the new shape is compatible with the original shape
82    if total_elements != array.len() {
83        return Err("New shape dimensions must match the total number of elements");
84    }
85
86    // Create a new array with the specified shape
87    let mut result = Array::<T, Ix2>::default(shape);
88
89    // Fill the result array with elements from the input array
90    let flat_iter = array.iter();
91    for (i, val) in flat_iter.enumerate() {
92        let r = i / cols;
93        let c = i % cols;
94        result[[r, c]] = val.clone();
95    }
96
97    Ok(result)
98}
99
100/// Stack 2D arrays along a given axis
101///
102/// # Arguments
103///
104/// * `arrays` - A slice of 2D arrays to stack
105/// * `axis` - The axis along which to stack (0 for rows, 1 for columns)
106///
107/// # Returns
108///
109/// A new array containing the stacked arrays
110///
111/// # Examples
112///
113/// ```
114/// use ndarray::{array, Axis};
115/// use scirs2_core::ndarray_ext::stack_2d;
116///
117/// let a = array![[1, 2], [3, 4]];
118/// let b = array![[5, 6], [7, 8]];
119/// let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
120/// assert_eq!(c.shape(), &[4, 2]);
121/// ```
122#[allow(dead_code)]
123pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
124where
125    T: Clone + Default,
126{
127    if arrays.is_empty() {
128        return Err("No _arrays provided for stacking");
129    }
130
131    // Validate that all _arrays have the same shape
132    let firstshape = arrays[0].shape();
133    for array in arrays.iter().skip(1) {
134        if array.shape() != firstshape {
135            return Err("All _arrays must have the same shape for stacking");
136        }
137    }
138
139    let (rows, cols) = (firstshape[0], firstshape[1]);
140
141    // Calculate the new shape
142    let (new_rows, new_cols) = match axis {
143        0 => (rows * arrays.len(), cols), // Stack vertically
144        1 => (rows, cols * arrays.len()), // Stack horizontally
145        _ => return Err("Axis must be 0 or 1 for 2D _arrays"),
146    };
147
148    // Create a new array to hold the stacked result
149    let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
150
151    // Copy data from the input _arrays to the result
152    match axis {
153        0 => {
154            // Stack vertically (along rows)
155            for (array_idx, array) in arrays.iter().enumerate() {
156                let start_row = array_idx * rows;
157                for r in 0..rows {
158                    for c in 0..cols {
159                        result[[start_row + r, c]] = array[[r, c]].clone();
160                    }
161                }
162            }
163        }
164        1 => {
165            // Stack horizontally (along columns)
166            for (array_idx, array) in arrays.iter().enumerate() {
167                let start_col = array_idx * cols;
168                for r in 0..rows {
169                    for c in 0..cols {
170                        result[[r, start_col + c]] = array[[r, c]].clone();
171                    }
172                }
173            }
174        }
175        _ => unreachable!(),
176    }
177
178    Ok(result)
179}
180
181/// Swap axes (transpose) of a 2D array
182///
183/// # Arguments
184///
185/// * `array` - The input 2D array
186///
187/// # Returns
188///
189/// A view of the input array with the axes swapped
190///
191/// # Examples
192///
193/// ```ignore
194/// use ndarray::array;
195/// use scirs2_core::ndarray_ext::transpose_2d;
196///
197/// let a = array![[1, 2, 3], [4, 5, 6]];
198/// let b = transpose_2d(a.view());
199/// assert_eq!(b.shape(), &[3, 2]);
200/// assert_eq!(b[[0, 0]], 1);
201/// assert_eq!(b[[0, 1]], 4);
202/// ```
203#[allow(dead_code)]
204pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
205where
206    T: Clone,
207{
208    array.t().to_owned()
209}
210
211/// Split a 2D array into multiple sub-arrays along a given axis
212///
213/// # Arguments
214///
215/// * `array` - The input 2D array to split
216/// * `indices` - Indices where the array should be split
217/// * `axis` - The axis along which to split (0 for rows, 1 for columns)
218///
219/// # Returns
220///
221/// A vector of arrays resulting from the split
222///
223/// # Examples
224///
225/// ```
226/// use ndarray::array;
227/// use scirs2_core::ndarray_ext::split_2d;
228///
229/// let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
230/// let result = split_2d(a.view(), &[2], 1).unwrap();
231/// assert_eq!(result.len(), 2);
232/// assert_eq!(result[0].shape(), &[2, 2]);
233/// assert_eq!(result[1].shape(), &[2, 2]);
234/// ```
235#[allow(dead_code)]
236pub fn split_2d<T>(
237    array: ArrayView<T, Ix2>,
238    indices: &[usize],
239    axis: usize,
240) -> Result<Vec<Array<T, Ix2>>, &'static str>
241where
242    T: Clone + Default,
243{
244    if indices.is_empty() {
245        return Ok(vec![array.to_owned()]);
246    }
247
248    let (rows, cols) = (array.shape()[0], array.shape()[1]);
249    let axis_len = if axis == 0 { rows } else { cols };
250
251    // Validate indices
252    for &idx in indices {
253        if idx >= axis_len {
254            return Err("Split index out of bounds");
255        }
256    }
257
258    // Sort indices to ensure they're in ascending order
259    let mut sorted_indices = indices.to_vec();
260    sorted_indices.sort_unstable();
261
262    // Calculate the sub-array boundaries
263    let mut starts = vec![0];
264    starts.extend_from_slice(&sorted_indices);
265
266    let mut ends = sorted_indices.clone();
267    ends.push(axis_len);
268
269    // Create the split sub-arrays
270    let mut result = Vec::with_capacity(starts.len());
271
272    match axis {
273        0 => {
274            // Split along rows
275            for (&start, &end) in starts.iter().zip(ends.iter()) {
276                let sub_rows = end - start;
277                let mut sub_array = Array::<T, Ix2>::default((sub_rows, cols));
278
279                for r in 0..sub_rows {
280                    for c in 0..cols {
281                        sub_array[[r, c]] = array[[start + r, c]].clone();
282                    }
283                }
284
285                result.push(sub_array);
286            }
287        }
288        1 => {
289            // Split along columns
290            for (&start, &end) in starts.iter().zip(ends.iter()) {
291                let sub_cols = end - start;
292                let mut sub_array = Array::<T, Ix2>::default((rows, sub_cols));
293
294                for r in 0..rows {
295                    for c in 0..sub_cols {
296                        sub_array[[r, c]] = array[[r, start + c]].clone();
297                    }
298                }
299
300                result.push(sub_array);
301            }
302        }
303        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
304    }
305
306    Ok(result)
307}
308
309/// Take elements from a 2D array along a given axis using indices from another array
310///
311/// # Arguments
312///
313/// * `array` - The input 2D array
314/// * `indices` - Array of indices to take
315/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
316///
317/// # Returns
318///
319/// An array of values at the specified indices along the given axis
320///
321/// # Examples
322///
323/// ```
324/// use ndarray::array;
325/// use scirs2_core::ndarray_ext::take_2d;
326///
327/// let a = array![[1, 2, 3], [4, 5, 6]];
328/// let indices = array![0, 2];
329/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
330/// assert_eq!(result.shape(), &[2, 2]);
331/// assert_eq!(result[[0, 0]], 1);
332/// assert_eq!(result[[0, 1]], 3);
333/// ```
334#[allow(dead_code)]
335pub fn take_2d<T>(
336    array: ArrayView<T, Ix2>,
337    indices: ArrayView<usize, Ix1>,
338    axis: usize,
339) -> Result<Array<T, Ix2>, &'static str>
340where
341    T: Clone + Default,
342{
343    let (rows, cols) = (array.shape()[0], array.shape()[1]);
344    let axis_len = if axis == 0 { rows } else { cols };
345
346    // Check that indices are within bounds
347    for &idx in indices.iter() {
348        if idx >= axis_len {
349            return Err("Index out of bounds");
350        }
351    }
352
353    // Create the result array with the appropriate shape
354    let (result_rows, result_cols) = match axis {
355        0 => (indices.len(), cols),
356        1 => (rows, indices.len()),
357        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
358    };
359
360    let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
361
362    // Fill the result array
363    match axis {
364        0 => {
365            // Take along rows
366            for (i, &idx) in indices.iter().enumerate() {
367                for j in 0..cols {
368                    result[[i, j]] = array[[idx, j]].clone();
369                }
370            }
371        }
372        1 => {
373            // Take along columns
374            for i in 0..rows {
375                for (j, &idx) in indices.iter().enumerate() {
376                    result[[i, j]] = array[[i, idx]].clone();
377                }
378            }
379        }
380        _ => unreachable!(),
381    }
382
383    Ok(result)
384}
385
386/// Filter an array using a boolean mask
387///
388/// # Arguments
389///
390/// * `array` - The input array
391/// * `mask` - Boolean mask of the same shape as the array
392///
393/// # Returns
394///
395/// A 1D array containing the elements where the mask is true
396///
397/// # Examples
398///
399/// ```ignore
400/// use ndarray::array;
401/// use scirs2_core::ndarray_ext::mask_select;
402///
403/// let a = array![[1, 2, 3], [4, 5, 6]];
404/// let mask = array![[true, false, true], [false, true, false]];
405/// let result = mask_select(a.view(), mask.view()).unwrap();
406/// assert_eq!(result.shape(), &[3]);
407/// assert_eq!(result[0], 1);
408/// assert_eq!(result[1], 3);
409/// assert_eq!(result[2], 5);
410/// ```
411#[allow(dead_code)]
412pub fn mask_select<T>(
413    array: ArrayView<T, Ix2>,
414    mask: ArrayView<bool, Ix2>,
415) -> Result<Array<T, Ix1>, &'static str>
416where
417    T: Clone + Default,
418{
419    // Check that the mask has the same shape as the array
420    if array.shape() != mask.shape() {
421        return Err("Mask shape must match array shape");
422    }
423
424    // Count the number of true values in the mask
425    let true_count = mask.iter().filter(|&&x| x).count();
426
427    // Create the result array
428    let mut result = Array::<T, Ix1>::default(true_count);
429
430    // Fill the result array with elements where the mask is true
431    let mut idx = 0;
432    for (val, &m) in array.iter().zip(mask.iter()) {
433        if m {
434            result[idx] = val.clone();
435            idx += 1;
436        }
437    }
438
439    Ok(result)
440}
441
442/// Index a 2D array with a list of index arrays (fancy indexing)
443///
444/// # Arguments
445///
446/// * `array` - The input 2D array
447/// * `row_indices` - Array of row indices
448/// * `col_indices` - Array of column indices
449///
450/// # Returns
451///
452/// A 1D array containing the elements at the specified indices
453///
454/// # Examples
455///
456/// ```ignore
457/// use ndarray::array;
458/// use scirs2_core::ndarray_ext::fancy_index_2d;
459///
460/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
461/// let row_indices = array![0, 2];
462/// let col_indices = array![0, 1];
463/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
464/// assert_eq!(result.shape(), &[2]);
465/// assert_eq!(result[0], 1);
466/// assert_eq!(result[1], 8);
467/// ```
468#[allow(dead_code)]
469pub fn fancy_index_2d<T>(
470    array: ArrayView<T, Ix2>,
471    row_indices: ArrayView<usize, Ix1>,
472    col_indices: ArrayView<usize, Ix1>,
473) -> Result<Array<T, Ix1>, &'static str>
474where
475    T: Clone + Default,
476{
477    // Check that all index arrays have the same length
478    let result_size = row_indices.len();
479    if col_indices.len() != result_size {
480        return Err("Row and column index arrays must have the same length");
481    }
482
483    let (rows, cols) = (array.shape()[0], array.shape()[1]);
484
485    // Check that _indices are within bounds
486    for &idx in row_indices.iter() {
487        if idx >= rows {
488            return Err("Row index out of bounds");
489        }
490    }
491
492    for &idx in col_indices.iter() {
493        if idx >= cols {
494            return Err("Column index out of bounds");
495        }
496    }
497
498    // Create the result array
499    let mut result = Array::<T, Ix1>::default(result_size);
500
501    // Fill the result array
502    for i in 0..result_size {
503        let row = row_indices[i];
504        let col = col_indices[i];
505        result[i] = array[[row, col]].clone();
506    }
507
508    Ok(result)
509}
510
511/// Select elements from an array where a condition is true
512///
513/// # Arguments
514///
515/// * `array` - The input array
516/// * `condition` - A function that takes a reference to an element and returns a bool
517///
518/// # Returns
519///
520/// A 1D array containing the elements where the condition is true
521///
522/// # Examples
523///
524/// ```ignore
525/// use ndarray::array;
526/// use scirs2_core::ndarray_ext::where_condition;
527///
528/// let a = array![[1, 2, 3], [4, 5, 6]];
529/// let result = where_condition(a.view(), |&x| x > 3).unwrap();
530/// assert_eq!(result.shape(), &[3]);
531/// assert_eq!(result[0], 4);
532/// assert_eq!(result[1], 5);
533/// assert_eq!(result[2], 6);
534/// ```
535#[allow(dead_code)]
536pub fn where_condition<T, F>(
537    array: ArrayView<T, Ix2>,
538    condition: F,
539) -> Result<Array<T, Ix1>, &'static str>
540where
541    T: Clone + Default,
542    F: Fn(&T) -> bool,
543{
544    // Build a boolean mask array based on the condition
545    let mask = array.map(condition);
546
547    // Use the mask_select function to select elements
548    mask_select(array, mask.view())
549}
550
/// Check whether two shapes can be broadcast together under NumPy rules
///
/// Shapes are aligned at their trailing dimensions. Each aligned pair must either be equal
/// or contain a 1. A dimension present in only one shape is treated as 1 in the other.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// `true` if the shapes are broadcast compatible, `false` otherwise
///
/// # Examples
///
/// ```ignore
/// use scirs2_core::ndarray_ext::is_broadcast_compatible;
///
/// assert!(is_broadcast_compatible(&[2, 3], &[3]));
/// // This example has dimensions that don't match (5 vs 3 in dimension 0)
/// // and aren't broadcasting compatible (neither is 1)
/// assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
/// assert!(!is_broadcast_compatible(&[2, 3], &[4]));
/// ```
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    // The leading dimensions of the longer shape pair with implicit 1s and always pass,
    // so only the overlapping trailing dimensions have to be compared.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&a, &b)| a == b || a == 1 || b == 1)
}
601
/// Calculate the broadcasted shape from two input shapes
///
/// Follows NumPy broadcasting rules. Shapes are aligned at their trailing dimensions, and
/// the shorter shape is implicitly left-padded with 1s. Each aligned pair of dimensions must
/// either be equal or contain a 1, and the output dimension is the non-1 member of the pair.
/// For example, a length-0 axis broadcast against a length-1 axis yields 0, not 1.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// The broadcasted shape as a `Vec<usize>`, or `None` if the shapes are incompatible
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::broadcastshape;
///
/// assert_eq!(broadcastshape(&[8, 1, 6, 1], &[7, 1, 5]), Some(vec![8, 7, 6, 5]));
/// assert_eq!(broadcastshape(&[0], &[1]), Some(vec![0]));
/// assert_eq!(broadcastshape(&[2, 3], &[4]), None);
/// ```
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let ndim = shape1.len().max(shape2.len());

    // Dimension `i` of `shape` after left-padding it with 1s up to `ndim` dimensions.
    let padded_dim = |shape: &[usize], i: usize| -> usize {
        let offset = ndim - shape.len();
        if i < offset {
            1
        } else {
            shape[i - offset]
        }
    };

    // Collecting into `Option<Vec<_>>` stops at the first incompatible pair with `None`.
    // Taking the non-1 dimension (rather than the max) keeps 0-length axes at 0.
    (0..ndim)
        .map(|i| {
            let (d1, d2) = (padded_dim(shape1, i), padded_dim(shape2, i));
            if d1 == d2 || d2 == 1 {
                Some(d1)
            } else if d1 == 1 {
                Some(d2)
            } else {
                None
            }
        })
        .collect()
}
643
644/// Broadcast a 1D array to a 2D shape by repeating it along the specified axis
645///
646/// # Arguments
647///
648/// * `array` - The input 1D array
649/// * `repeats` - Number of times to repeat the array
650/// * `axis` - Axis along which to repeat (0 for rows, 1 for columns)
651///
652/// # Returns
653///
654/// A 2D array with the input array repeated along the specified axis
655///
656/// # Examples
657///
658/// ```ignore
659/// use ndarray::array;
660/// use scirs2_core::ndarray_ext::broadcast_1d_to_2d;
661///
662/// let a = array![1, 2, 3];
663/// let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
664/// assert_eq!(b.shape(), &[2, 3]);
665/// assert_eq!(b[[0, 0]], 1);
666/// assert_eq!(b[[1, 0]], 1);
667/// ```
668#[allow(dead_code)]
669pub fn broadcast_1d_to_2d<T>(
670    array: ArrayView<T, Ix1>,
671    repeats: usize,
672    axis: usize,
673) -> Result<Array<T, Ix2>, &'static str>
674where
675    T: Clone + Default,
676{
677    let len = array.len();
678
679    // Create the result array with the appropriate shape
680    let (rows, cols) = match axis {
681        0 => (repeats, len), // Broadcast along rows
682        1 => (len, repeats), // Broadcast along columns
683        _ => return Err("Axis must be 0 or 1"),
684    };
685
686    let mut result = Array::<T, Ix2>::default((rows, cols));
687
688    // Fill the result array
689    match axis {
690        0 => {
691            // Broadcast along rows
692            for i in 0..repeats {
693                for j in 0..len {
694                    result[[i, j]] = array[j].clone();
695                }
696            }
697        }
698        1 => {
699            // Broadcast along columns
700            for i in 0..len {
701                for j in 0..repeats {
702                    result[[i, j]] = array[i].clone();
703                }
704            }
705        }
706        _ => unreachable!(),
707    }
708
709    Ok(result)
710}
711
712/// Apply an element-wise binary operation to two arrays with broadcasting
713///
714/// # Arguments
715///
716/// * `a` - First array (2D)
717/// * `b` - Second array (can be 1D or 2D)
718/// * `op` - Binary operation to apply to each pair of elements
719///
720/// # Returns
721///
722/// A 2D array containing the result of the operation applied element-wise
723///
724/// # Examples
725///
726/// ```ignore
727/// use ndarray::array;
728/// use scirs2_core::ndarray_ext::broadcast_apply;
729///
730/// let a = array![[1, 2, 3], [4, 5, 6]];
731/// let b = array![10, 20, 30];
732/// let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
733/// assert_eq!(result.shape(), &[2, 3]);
734/// assert_eq!(result[[0, 0]], 11);
735/// assert_eq!(result[[1, 2]], 36);
736/// ```
737#[allow(dead_code)]
738pub fn broadcast_apply<T, R, F>(
739    a: ArrayView<T, Ix2>,
740    b: ArrayView<T, Ix1>,
741    op: F,
742) -> Result<Array<R, Ix2>, &'static str>
743where
744    T: Clone + Default,
745    R: Clone + Default,
746    F: Fn(&T, &T) -> R,
747{
748    let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
749    let b_len = b.len();
750
751    // Check that the arrays are broadcast compatible
752    if a_cols != b_len {
753        return Err("Arrays are not broadcast compatible");
754    }
755
756    // Create the result array
757    let mut result = Array::<R, Ix2>::default((a_rows, a_cols));
758
759    // Apply the operation element-wise
760    for i in 0..a_rows {
761        for j in 0..a_cols {
762            result[[i, j]] = op(&a[[i, j]], &b[j]);
763        }
764    }
765
766    Ok(result)
767}
768
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_reshape_2d() {
        let a = array![[1, 2], [3, 4]];

        // Elements are laid out in row-major order in the new shape.
        let b = reshape_2d(a.view(), (4, 1)).unwrap();
        assert_eq!(b, array![[1], [2], [3], [4]]);

        let c = reshape_2d(a.view(), (1, 4)).unwrap();
        assert_eq!(c, array![[1, 2, 3, 4]]);

        // Test invalid shape
        assert!(reshape_2d(a.view(), (3, 1)).is_err());
    }

    #[test]
    fn test_stack_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        // Stack vertically (along axis 0)
        let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
        assert_eq!(c, array![[1, 2], [3, 4], [5, 6], [7, 8]]);

        // Stack horizontally (along axis 1)
        let d = stack_2d(&[a.view(), b.view()], 1).unwrap();
        assert_eq!(d, array![[1, 2, 5, 6], [3, 4, 7, 8]]);

        // Invalid inputs: nothing to stack, mismatched shapes, unsupported axis
        let none: [ArrayView<i32, Ix2>; 0] = [];
        assert!(stack_2d(&none, 0).is_err());
        let wide = array![[1, 2, 3]];
        assert!(stack_2d(&[a.view(), wide.view()], 0).is_err());
        assert!(stack_2d(&[a.view(), b.view()], 2).is_err());
    }

    #[test]
    fn test_transpose_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = transpose_2d(a.view());
        assert_eq!(b, array![[1, 4], [2, 5], [3, 6]]);
    }

    #[test]
    fn test_split_2d() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Split along columns at index 2
        let result = split_2d(a.view(), &[2], 1).unwrap();
        assert_eq!(result, vec![array![[1, 2], [5, 6]], array![[3, 4], [7, 8]]]);

        // Split along rows at index 1
        let result = split_2d(a.view(), &[1], 0).unwrap();
        assert_eq!(result, vec![array![[1, 2, 3, 4]], array![[5, 6, 7, 8]]]);

        // Split points are sorted before use
        let result = split_2d(a.view(), &[3, 1], 1).unwrap();
        assert_eq!(
            result,
            vec![array![[1], [5]], array![[2, 3], [6, 7]], array![[4], [8]]]
        );

        // No split points yields a single copy of the input
        let result = split_2d(a.view(), &[], 0).unwrap();
        assert_eq!(result, vec![a.clone()]);

        // Out-of-bounds split point
        assert!(split_2d(a.view(), &[4], 1).is_err());
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let indices = array![0, 2];

        // Take along columns
        let result = take_2d(a.view(), indices.view(), 1).unwrap();
        assert_eq!(result, array![[1, 3], [4, 6]]);

        // Take along rows, in a custom order
        let row_order = array![1, 0];
        let rows = take_2d(a.view(), row_order.view(), 0).unwrap();
        assert_eq!(rows, array![[4, 5, 6], [1, 2, 3]]);

        // Out-of-bounds index
        let too_far = array![2];
        assert!(take_2d(a.view(), too_far.view(), 0).is_err());
    }

    #[test]
    fn test_mask_select() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let result = mask_select(a.view(), mask.view()).unwrap();
        assert_eq!(result, array![1, 3, 5]);

        // Mask shape must match the array shape
        let wrong = array![[true, false]];
        assert!(mask_select(a.view(), wrong.view()).is_err());
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(result, array![1, 8]);

        // Index arrays of different lengths
        let short = array![0];
        assert!(fancy_index_2d(a.view(), row_indices.view(), short.view()).is_err());

        // Out-of-bounds row index
        let bad_rows = array![0, 3];
        assert!(fancy_index_2d(a.view(), bad_rows.view(), col_indices.view()).is_err());
    }

    #[test]
    fn test_where_condition() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let result = where_condition(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result, array![4, 5, 6]);
    }

    #[test]
    fn test_broadcast_shapes() {
        assert!(is_broadcast_compatible(&[2, 3], &[3]));
        assert!(!is_broadcast_compatible(&[2, 3], &[4]));
        assert_eq!(broadcastshape(&[8, 1, 6, 1], &[7, 1, 5]), Some(vec![8, 7, 6, 5]));
        assert_eq!(broadcastshape(&[2, 3], &[4]), None);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let a = array![1, 2, 3];

        // Broadcast along rows (axis 0)
        let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
        assert_eq!(b, array![[1, 2, 3], [1, 2, 3]]);

        // Broadcast along columns (axis 1)
        let c = broadcast_1d_to_2d(a.view(), 2, 1).unwrap();
        assert_eq!(c, array![[1, 1], [2, 2], [3, 3]]);

        // Unsupported axis
        assert!(broadcast_1d_to_2d(a.view(), 2, 2).is_err());
    }

    #[test]
    fn test_broadcast_apply() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = array![10, 20, 30];

        let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
        assert_eq!(result, array![[11, 22, 33], [14, 25, 36]]);

        let result = broadcast_apply(a.view(), b.view(), |x, y| x * y).unwrap();
        assert_eq!(result, array![[10, 40, 90], [40, 100, 180]]);

        // Length that matches neither the column count nor 1
        let short = array![1, 2];
        assert!(broadcast_apply(a.view(), short.view(), |x, y| x + y).is_err());
    }
}