scirs2_core/ndarray_ext/mod.rs
1//! Extended ndarray operations for scientific computing
2//!
3//! This module provides additional functionality for ndarray to support
4//! the advanced array operations necessary for a complete SciPy port.
5//! It implements core `NumPy`-like features that are not available in the
6//! base ndarray crate.
7
8/// Re-export essential ndarray types and macros for convenience
9pub use ndarray::{
10 s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, DataMut, Dim, Dimension, Ix1, Ix2,
11 Ix3, Ix4, Ix5, Ix6, IxDyn, OwnedRepr, RemoveAxis, ScalarOperand, ShapeBuilder, ShapeError,
12 SliceInfo, ViewRepr, Zip,
13};
14
15// Re-export the array! macro for convenient array creation
16// This addresses the common issue where users expect array! to be available in scirs2-core
17// instead of requiring import from scirs2_autograd::ndarray
18pub use ndarray::{arr1, arr2, array};
19
20// Re-export type aliases for common array sizes
21pub use ndarray::{ArcArray1, ArcArray2};
22pub use ndarray::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD};
23pub use ndarray::{
24 ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD,
25};
26pub use ndarray::{
27 ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5,
28 ArrayViewMut6, ArrayViewMutD,
29};
30
31/// Advanced indexing operations (`NumPy`-like boolean masking, fancy indexing, etc.)
32pub mod indexing;
33
34/// Statistical functions for ndarray arrays (mean, median, variance, correlation, etc.)
35pub mod stats;
36
37/// Matrix operations (eye, diag, kron, etc.)
38pub mod matrix;
39
40/// Array manipulation operations (flip, roll, tile, repeat, etc.)
41pub mod manipulation;
42
43/// Random array generation with full ndarray-rand compatibility and scientific computing extensions
44#[cfg(feature = "random")]
45pub mod random;
46
47/// Reshape a 2D array to a new shape without copying data when possible
48///
49/// # Arguments
50///
51/// * `array` - The input array
52/// * `shape` - The new shape (rows, cols), which must be compatible with the original shape
53///
54/// # Returns
55///
56/// A reshaped array view of the input array if possible, or a new array otherwise
57///
58/// # Examples
59///
60/// ```
61/// use ndarray::array;
62/// use scirs2_core::ndarray_ext::reshape_2d;
63///
64/// let a = array![[1, 2], [3, 4]];
65/// let b = reshape_2d(a.view(), (4, 1)).unwrap();
66/// assert_eq!(b.shape(), &[4, 1]);
67/// assert_eq!(b[[0, 0]], 1);
68/// assert_eq!(b[[3, 0]], 4);
69/// ```
70#[allow(dead_code)]
71pub fn reshape_2d<T>(
72 array: ArrayView<T, Ix2>,
73 shape: (usize, usize),
74) -> Result<Array<T, Ix2>, &'static str>
75where
76 T: Clone + Default,
77{
78 let (rows, cols) = shape;
79 let total_elements = rows * cols;
80
81 // Check if the new shape is compatible with the original shape
82 if total_elements != array.len() {
83 return Err("New shape dimensions must match the total number of elements");
84 }
85
86 // Create a new array with the specified shape
87 let mut result = Array::<T, Ix2>::default(shape);
88
89 // Fill the result array with elements from the input array
90 let flat_iter = array.iter();
91 for (i, val) in flat_iter.enumerate() {
92 let r = i / cols;
93 let c = i % cols;
94 result[[r, c]] = val.clone();
95 }
96
97 Ok(result)
98}
99
100/// Stack 2D arrays along a given axis
101///
102/// # Arguments
103///
104/// * `arrays` - A slice of 2D arrays to stack
105/// * `axis` - The axis along which to stack (0 for rows, 1 for columns)
106///
107/// # Returns
108///
109/// A new array containing the stacked arrays
110///
111/// # Examples
112///
113/// ```
114/// use ndarray::{array, Axis};
115/// use scirs2_core::ndarray_ext::stack_2d;
116///
117/// let a = array![[1, 2], [3, 4]];
118/// let b = array![[5, 6], [7, 8]];
119/// let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
120/// assert_eq!(c.shape(), &[4, 2]);
121/// ```
122#[allow(dead_code)]
123pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
124where
125 T: Clone + Default,
126{
127 if arrays.is_empty() {
128 return Err("No _arrays provided for stacking");
129 }
130
131 // Validate that all _arrays have the same shape
132 let firstshape = arrays[0].shape();
133 for array in arrays.iter().skip(1) {
134 if array.shape() != firstshape {
135 return Err("All _arrays must have the same shape for stacking");
136 }
137 }
138
139 let (rows, cols) = (firstshape[0], firstshape[1]);
140
141 // Calculate the new shape
142 let (new_rows, new_cols) = match axis {
143 0 => (rows * arrays.len(), cols), // Stack vertically
144 1 => (rows, cols * arrays.len()), // Stack horizontally
145 _ => return Err("Axis must be 0 or 1 for 2D _arrays"),
146 };
147
148 // Create a new array to hold the stacked result
149 let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
150
151 // Copy data from the input _arrays to the result
152 match axis {
153 0 => {
154 // Stack vertically (along rows)
155 for (array_idx, array) in arrays.iter().enumerate() {
156 let start_row = array_idx * rows;
157 for r in 0..rows {
158 for c in 0..cols {
159 result[[start_row + r, c]] = array[[r, c]].clone();
160 }
161 }
162 }
163 }
164 1 => {
165 // Stack horizontally (along columns)
166 for (array_idx, array) in arrays.iter().enumerate() {
167 let start_col = array_idx * cols;
168 for r in 0..rows {
169 for c in 0..cols {
170 result[[r, start_col + c]] = array[[r, c]].clone();
171 }
172 }
173 }
174 }
175 _ => unreachable!(),
176 }
177
178 Ok(result)
179}
180
181/// Swap axes (transpose) of a 2D array
182///
183/// # Arguments
184///
185/// * `array` - The input 2D array
186///
187/// # Returns
188///
189/// A view of the input array with the axes swapped
190///
191/// # Examples
192///
193/// ```ignore
194/// use ndarray::array;
195/// use scirs2_core::ndarray_ext::transpose_2d;
196///
197/// let a = array![[1, 2, 3], [4, 5, 6]];
198/// let b = transpose_2d(a.view());
199/// assert_eq!(b.shape(), &[3, 2]);
200/// assert_eq!(b[[0, 0]], 1);
201/// assert_eq!(b[[0, 1]], 4);
202/// ```
203#[allow(dead_code)]
204pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
205where
206 T: Clone,
207{
208 array.t().to_owned()
209}
210
211/// Split a 2D array into multiple sub-arrays along a given axis
212///
213/// # Arguments
214///
215/// * `array` - The input 2D array to split
216/// * `indices` - Indices where the array should be split
217/// * `axis` - The axis along which to split (0 for rows, 1 for columns)
218///
219/// # Returns
220///
221/// A vector of arrays resulting from the split
222///
223/// # Examples
224///
225/// ```
226/// use ndarray::array;
227/// use scirs2_core::ndarray_ext::split_2d;
228///
229/// let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
230/// let result = split_2d(a.view(), &[2], 1).unwrap();
231/// assert_eq!(result.len(), 2);
232/// assert_eq!(result[0].shape(), &[2, 2]);
233/// assert_eq!(result[1].shape(), &[2, 2]);
234/// ```
235#[allow(dead_code)]
236pub fn split_2d<T>(
237 array: ArrayView<T, Ix2>,
238 indices: &[usize],
239 axis: usize,
240) -> Result<Vec<Array<T, Ix2>>, &'static str>
241where
242 T: Clone + Default,
243{
244 if indices.is_empty() {
245 return Ok(vec![array.to_owned()]);
246 }
247
248 let (rows, cols) = (array.shape()[0], array.shape()[1]);
249 let axis_len = if axis == 0 { rows } else { cols };
250
251 // Validate indices
252 for &idx in indices {
253 if idx >= axis_len {
254 return Err("Split index out of bounds");
255 }
256 }
257
258 // Sort indices to ensure they're in ascending order
259 let mut sorted_indices = indices.to_vec();
260 sorted_indices.sort_unstable();
261
262 // Calculate the sub-array boundaries
263 let mut starts = vec![0];
264 starts.extend_from_slice(&sorted_indices);
265
266 let mut ends = sorted_indices.clone();
267 ends.push(axis_len);
268
269 // Create the split sub-arrays
270 let mut result = Vec::with_capacity(starts.len());
271
272 match axis {
273 0 => {
274 // Split along rows
275 for (&start, &end) in starts.iter().zip(ends.iter()) {
276 let sub_rows = end - start;
277 let mut sub_array = Array::<T, Ix2>::default((sub_rows, cols));
278
279 for r in 0..sub_rows {
280 for c in 0..cols {
281 sub_array[[r, c]] = array[[start + r, c]].clone();
282 }
283 }
284
285 result.push(sub_array);
286 }
287 }
288 1 => {
289 // Split along columns
290 for (&start, &end) in starts.iter().zip(ends.iter()) {
291 let sub_cols = end - start;
292 let mut sub_array = Array::<T, Ix2>::default((rows, sub_cols));
293
294 for r in 0..rows {
295 for c in 0..sub_cols {
296 sub_array[[r, c]] = array[[r, start + c]].clone();
297 }
298 }
299
300 result.push(sub_array);
301 }
302 }
303 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
304 }
305
306 Ok(result)
307}
308
309/// Take elements from a 2D array along a given axis using indices from another array
310///
311/// # Arguments
312///
313/// * `array` - The input 2D array
314/// * `indices` - Array of indices to take
315/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
316///
317/// # Returns
318///
319/// An array of values at the specified indices along the given axis
320///
321/// # Examples
322///
323/// ```
324/// use ndarray::array;
325/// use scirs2_core::ndarray_ext::take_2d;
326///
327/// let a = array![[1, 2, 3], [4, 5, 6]];
328/// let indices = array![0, 2];
329/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
330/// assert_eq!(result.shape(), &[2, 2]);
331/// assert_eq!(result[[0, 0]], 1);
332/// assert_eq!(result[[0, 1]], 3);
333/// ```
334#[allow(dead_code)]
335pub fn take_2d<T>(
336 array: ArrayView<T, Ix2>,
337 indices: ArrayView<usize, Ix1>,
338 axis: usize,
339) -> Result<Array<T, Ix2>, &'static str>
340where
341 T: Clone + Default,
342{
343 let (rows, cols) = (array.shape()[0], array.shape()[1]);
344 let axis_len = if axis == 0 { rows } else { cols };
345
346 // Check that indices are within bounds
347 for &idx in indices.iter() {
348 if idx >= axis_len {
349 return Err("Index out of bounds");
350 }
351 }
352
353 // Create the result array with the appropriate shape
354 let (result_rows, result_cols) = match axis {
355 0 => (indices.len(), cols),
356 1 => (rows, indices.len()),
357 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
358 };
359
360 let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
361
362 // Fill the result array
363 match axis {
364 0 => {
365 // Take along rows
366 for (i, &idx) in indices.iter().enumerate() {
367 for j in 0..cols {
368 result[[i, j]] = array[[idx, j]].clone();
369 }
370 }
371 }
372 1 => {
373 // Take along columns
374 for i in 0..rows {
375 for (j, &idx) in indices.iter().enumerate() {
376 result[[i, j]] = array[[i, idx]].clone();
377 }
378 }
379 }
380 _ => unreachable!(),
381 }
382
383 Ok(result)
384}
385
386/// Filter an array using a boolean mask
387///
388/// # Arguments
389///
390/// * `array` - The input array
391/// * `mask` - Boolean mask of the same shape as the array
392///
393/// # Returns
394///
395/// A 1D array containing the elements where the mask is true
396///
397/// # Examples
398///
399/// ```ignore
400/// use ndarray::array;
401/// use scirs2_core::ndarray_ext::mask_select;
402///
403/// let a = array![[1, 2, 3], [4, 5, 6]];
404/// let mask = array![[true, false, true], [false, true, false]];
405/// let result = mask_select(a.view(), mask.view()).unwrap();
406/// assert_eq!(result.shape(), &[3]);
407/// assert_eq!(result[0], 1);
408/// assert_eq!(result[1], 3);
409/// assert_eq!(result[2], 5);
410/// ```
411#[allow(dead_code)]
412pub fn mask_select<T>(
413 array: ArrayView<T, Ix2>,
414 mask: ArrayView<bool, Ix2>,
415) -> Result<Array<T, Ix1>, &'static str>
416where
417 T: Clone + Default,
418{
419 // Check that the mask has the same shape as the array
420 if array.shape() != mask.shape() {
421 return Err("Mask shape must match array shape");
422 }
423
424 // Count the number of true values in the mask
425 let true_count = mask.iter().filter(|&&x| x).count();
426
427 // Create the result array
428 let mut result = Array::<T, Ix1>::default(true_count);
429
430 // Fill the result array with elements where the mask is true
431 let mut idx = 0;
432 for (val, &m) in array.iter().zip(mask.iter()) {
433 if m {
434 result[idx] = val.clone();
435 idx += 1;
436 }
437 }
438
439 Ok(result)
440}
441
442/// Index a 2D array with a list of index arrays (fancy indexing)
443///
444/// # Arguments
445///
446/// * `array` - The input 2D array
447/// * `row_indices` - Array of row indices
448/// * `col_indices` - Array of column indices
449///
450/// # Returns
451///
452/// A 1D array containing the elements at the specified indices
453///
454/// # Examples
455///
456/// ```ignore
457/// use ndarray::array;
458/// use scirs2_core::ndarray_ext::fancy_index_2d;
459///
460/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
461/// let row_indices = array![0, 2];
462/// let col_indices = array![0, 1];
463/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
464/// assert_eq!(result.shape(), &[2]);
465/// assert_eq!(result[0], 1);
466/// assert_eq!(result[1], 8);
467/// ```
468#[allow(dead_code)]
469pub fn fancy_index_2d<T>(
470 array: ArrayView<T, Ix2>,
471 row_indices: ArrayView<usize, Ix1>,
472 col_indices: ArrayView<usize, Ix1>,
473) -> Result<Array<T, Ix1>, &'static str>
474where
475 T: Clone + Default,
476{
477 // Check that all index arrays have the same length
478 let result_size = row_indices.len();
479 if col_indices.len() != result_size {
480 return Err("Row and column index arrays must have the same length");
481 }
482
483 let (rows, cols) = (array.shape()[0], array.shape()[1]);
484
485 // Check that _indices are within bounds
486 for &idx in row_indices.iter() {
487 if idx >= rows {
488 return Err("Row index out of bounds");
489 }
490 }
491
492 for &idx in col_indices.iter() {
493 if idx >= cols {
494 return Err("Column index out of bounds");
495 }
496 }
497
498 // Create the result array
499 let mut result = Array::<T, Ix1>::default(result_size);
500
501 // Fill the result array
502 for i in 0..result_size {
503 let row = row_indices[i];
504 let col = col_indices[i];
505 result[i] = array[[row, col]].clone();
506 }
507
508 Ok(result)
509}
510
511/// Select elements from an array where a condition is true
512///
513/// # Arguments
514///
515/// * `array` - The input array
516/// * `condition` - A function that takes a reference to an element and returns a bool
517///
518/// # Returns
519///
520/// A 1D array containing the elements where the condition is true
521///
522/// # Examples
523///
524/// ```ignore
525/// use ndarray::array;
526/// use scirs2_core::ndarray_ext::where_condition;
527///
528/// let a = array![[1, 2, 3], [4, 5, 6]];
529/// let result = where_condition(a.view(), |&x| x > 3).unwrap();
530/// assert_eq!(result.shape(), &[3]);
531/// assert_eq!(result[0], 4);
532/// assert_eq!(result[1], 5);
533/// assert_eq!(result[2], 6);
534/// ```
535#[allow(dead_code)]
536pub fn where_condition<T, F>(
537 array: ArrayView<T, Ix2>,
538 condition: F,
539) -> Result<Array<T, Ix1>, &'static str>
540where
541 T: Clone + Default,
542 F: Fn(&T) -> bool,
543{
544 // Build a boolean mask array based on the condition
545 let mask = array.map(condition);
546
547 // Use the mask_select function to select elements
548 mask_select(array, mask.view())
549}
550
/// Report whether two shapes can be broadcast together
///
/// Shapes are aligned at their trailing dimensions. Each aligned pair must be
/// equal, or one side must be `1`. Leading dimensions present in only one
/// shape are treated as pairing with an implicit `1`, so they never cause a
/// conflict.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// `true` if the shapes are broadcast compatible, `false` otherwise.
///
/// # Examples
///
/// ```ignore
/// use scirs2_core::ndarray_ext::is_broadcast_compatible;
///
/// assert!(is_broadcast_compatible(&[2, 3], &[3]));
/// // Dimension 0 is 5 vs 3: they differ and neither is 1.
/// assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
/// assert!(!is_broadcast_compatible(&[2, 3], &[4]));
/// ```
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    // `zip` stops at the shorter shape. The dimensions it skips would only be
    // compared against implicit 1s, which always succeed.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&left, &right)| left == right || left == 1 || right == 1)
}
601
/// Calculate the broadcasted shape from two input shapes
///
/// Both shapes are left-padded with `1`s to the same length and compared
/// dimension by dimension. Equal dimensions are kept. A dimension of `1`
/// stretches to match the other side, including a size-0 dimension: `1`
/// against `0` yields `0`, not `1`. Any other mismatch makes the shapes
/// incompatible.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// The broadcasted shape as a `Vec<usize>`, or `None` if the shapes are incompatible.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::broadcastshape;
///
/// assert_eq!(broadcastshape(&[2, 1], &[3]), Some(vec![2, 3]));
/// assert_eq!(broadcastshape(&[0], &[1]), Some(vec![0]));
/// assert_eq!(broadcastshape(&[2, 3], &[4]), None);
/// ```
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let ndim = shape1.len().max(shape2.len());

    // Dimension `i` of `shape` after left-padding it with 1s to `ndim` dimensions.
    let dim_at = |shape: &[usize], i: usize| -> usize {
        let pad = ndim - shape.len();
        if i < pad {
            1
        } else {
            shape[i - pad]
        }
    };

    // Collecting `Option`s short-circuits to `None` on the first incompatible
    // dimension, so compatibility is checked in the same pass.
    (0..ndim)
        .map(|i| match (dim_at(shape1, i), dim_at(shape2, i)) {
            (d1, d2) if d1 == d2 => Some(d1),
            // A size-1 dimension takes the other side's size (which may be 0).
            (1, other) | (other, 1) => Some(other),
            _ => None,
        })
        .collect()
}
643
644/// Broadcast a 1D array to a 2D shape by repeating it along the specified axis
645///
646/// # Arguments
647///
648/// * `array` - The input 1D array
649/// * `repeats` - Number of times to repeat the array
650/// * `axis` - Axis along which to repeat (0 for rows, 1 for columns)
651///
652/// # Returns
653///
654/// A 2D array with the input array repeated along the specified axis
655///
656/// # Examples
657///
658/// ```ignore
659/// use ndarray::array;
660/// use scirs2_core::ndarray_ext::broadcast_1d_to_2d;
661///
662/// let a = array![1, 2, 3];
663/// let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
664/// assert_eq!(b.shape(), &[2, 3]);
665/// assert_eq!(b[[0, 0]], 1);
666/// assert_eq!(b[[1, 0]], 1);
667/// ```
668#[allow(dead_code)]
669pub fn broadcast_1d_to_2d<T>(
670 array: ArrayView<T, Ix1>,
671 repeats: usize,
672 axis: usize,
673) -> Result<Array<T, Ix2>, &'static str>
674where
675 T: Clone + Default,
676{
677 let len = array.len();
678
679 // Create the result array with the appropriate shape
680 let (rows, cols) = match axis {
681 0 => (repeats, len), // Broadcast along rows
682 1 => (len, repeats), // Broadcast along columns
683 _ => return Err("Axis must be 0 or 1"),
684 };
685
686 let mut result = Array::<T, Ix2>::default((rows, cols));
687
688 // Fill the result array
689 match axis {
690 0 => {
691 // Broadcast along rows
692 for i in 0..repeats {
693 for j in 0..len {
694 result[[i, j]] = array[j].clone();
695 }
696 }
697 }
698 1 => {
699 // Broadcast along columns
700 for i in 0..len {
701 for j in 0..repeats {
702 result[[i, j]] = array[i].clone();
703 }
704 }
705 }
706 _ => unreachable!(),
707 }
708
709 Ok(result)
710}
711
712/// Apply an element-wise binary operation to two arrays with broadcasting
713///
714/// # Arguments
715///
716/// * `a` - First array (2D)
717/// * `b` - Second array (can be 1D or 2D)
718/// * `op` - Binary operation to apply to each pair of elements
719///
720/// # Returns
721///
722/// A 2D array containing the result of the operation applied element-wise
723///
724/// # Examples
725///
726/// ```ignore
727/// use ndarray::array;
728/// use scirs2_core::ndarray_ext::broadcast_apply;
729///
730/// let a = array![[1, 2, 3], [4, 5, 6]];
731/// let b = array![10, 20, 30];
732/// let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
733/// assert_eq!(result.shape(), &[2, 3]);
734/// assert_eq!(result[[0, 0]], 11);
735/// assert_eq!(result[[1, 2]], 36);
736/// ```
737#[allow(dead_code)]
738pub fn broadcast_apply<T, R, F>(
739 a: ArrayView<T, Ix2>,
740 b: ArrayView<T, Ix1>,
741 op: F,
742) -> Result<Array<R, Ix2>, &'static str>
743where
744 T: Clone + Default,
745 R: Clone + Default,
746 F: Fn(&T, &T) -> R,
747{
748 let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
749 let b_len = b.len();
750
751 // Check that the arrays are broadcast compatible
752 if a_cols != b_len {
753 return Err("Arrays are not broadcast compatible");
754 }
755
756 // Create the result array
757 let mut result = Array::<R, Ix2>::default((a_rows, a_cols));
758
759 // Apply the operation element-wise
760 for i in 0..a_rows {
761 for j in 0..a_cols {
762 result[[i, j]] = op(&a[[i, j]], &b[j]);
763 }
764 }
765
766 Ok(result)
767}
768
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_reshape_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = reshape_2d(a.view(), (4, 1)).unwrap();
        assert_eq!(b.shape(), &[4, 1]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 0]], 3);
        assert_eq!(b[[3, 0]], 4);

        // Test invalid shape
        let result = reshape_2d(a.view(), (3, 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_stack_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        // Stack vertically (along axis 0)
        let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
        assert_eq!(c.shape(), &[4, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[1, 0]], 3);
        assert_eq!(c[[2, 0]], 5);
        assert_eq!(c[[3, 0]], 7);

        // Stack horizontally (along axis 1)
        let d = stack_2d(&[a.view(), b.view()], 1).unwrap();
        assert_eq!(d.shape(), &[2, 4]);
        assert_eq!(d[[0, 0]], 1);
        assert_eq!(d[[0, 1]], 2);
        assert_eq!(d[[0, 2]], 5);
        assert_eq!(d[[0, 3]], 6);
    }

    #[test]
    fn test_stack_2d_errors() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[1, 2, 3]];
        let empty: [ArrayView<i32, Ix2>; 0] = [];

        // No inputs, mismatched shapes, and an invalid axis are all rejected.
        assert!(stack_2d(&empty, 0).is_err());
        assert!(stack_2d(&[a.view(), b.view()], 0).is_err());
        assert!(stack_2d(&[a.view(), a.view()], 2).is_err());
    }

    #[test]
    fn test_transpose_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = transpose_2d(a.view());
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 4);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 1]], 6);
    }

    #[test]
    fn test_split_2d() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Split along columns at index 2
        let result = split_2d(a.view(), &[2], 1).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 2]);
        assert_eq!(result[0][[0, 0]], 1);
        assert_eq!(result[0][[0, 1]], 2);
        assert_eq!(result[0][[1, 0]], 5);
        assert_eq!(result[0][[1, 1]], 6);
        assert_eq!(result[1].shape(), &[2, 2]);
        assert_eq!(result[1][[0, 0]], 3);
        assert_eq!(result[1][[0, 1]], 4);
        assert_eq!(result[1][[1, 0]], 7);
        assert_eq!(result[1][[1, 1]], 8);

        // Split along rows at index 1
        let result = split_2d(a.view(), &[1], 0).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[1, 4]);
        assert_eq!(result[1].shape(), &[1, 4]);
    }

    #[test]
    fn test_split_2d_edge_cases() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // No split points: a single copy of the whole array.
        let whole = split_2d(a.view(), &[], 1).unwrap();
        assert_eq!(whole.len(), 1);
        assert_eq!(whole[0], a);

        // Unsorted split points are sorted before use.
        let parts = split_2d(a.view(), &[3, 1], 1).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0], array![[1], [5]]);
        assert_eq!(parts[1], array![[2, 3], [6, 7]]);
        assert_eq!(parts[2], array![[4], [8]]);

        // An index past the end of the axis is rejected.
        assert!(split_2d(a.view(), &[5], 1).is_err());
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let indices = array![0, 2];

        // Take along columns
        let result = take_2d(a.view(), indices.view(), 1).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 3);
        assert_eq!(result[[1, 0]], 4);
        assert_eq!(result[[1, 1]], 6);
    }

    #[test]
    fn test_take_2d_rows_and_errors() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        // Take along rows, with repetition and reordering.
        let picks = array![1, 0, 1];
        let taken = take_2d(a.view(), picks.view(), 0).unwrap();
        assert_eq!(taken, array![[4, 5, 6], [1, 2, 3], [4, 5, 6]]);

        // Column index equal to the number of columns is out of bounds.
        let bad = array![3];
        assert!(take_2d(a.view(), bad.view(), 1).is_err());
    }

    #[test]
    fn test_mask_select() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let result = mask_select(a.view(), mask.view()).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 3);
        assert_eq!(result[2], 5);
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(result.shape(), &[2]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 8);
    }

    #[test]
    fn test_input_validation_errors() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        let wrong_mask = array![[true, false], [false, true]];
        assert!(mask_select(a.view(), wrong_mask.view()).is_err());

        let rows = array![0, 1];
        let short_cols = array![0];
        assert!(fancy_index_2d(a.view(), rows.view(), short_cols.view()).is_err());
        let cols_oob = array![0, 3];
        assert!(fancy_index_2d(a.view(), rows.view(), cols_oob.view()).is_err());

        let v = array![1, 2, 3];
        assert!(broadcast_1d_to_2d(v.view(), 2, 2).is_err());

        let short = array![1, 2];
        assert!(broadcast_apply(a.view(), short.view(), |x, y| x + y).is_err());
    }

    #[test]
    fn test_where_condition() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let result = where_condition(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 4);
        assert_eq!(result[1], 5);
        assert_eq!(result[2], 6);
    }

    #[test]
    fn test_broadcast_shapes() {
        assert!(is_broadcast_compatible(&[2, 3], &[3]));
        assert!(is_broadcast_compatible(&[8, 1, 6, 1], &[7, 1, 5]));
        assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
        assert!(!is_broadcast_compatible(&[2, 3], &[4]));

        assert_eq!(
            broadcastshape(&[8, 1, 6, 1], &[7, 1, 5]),
            Some(vec![8, 7, 6, 5])
        );
        assert_eq!(broadcastshape(&[2, 3], &[4]), None);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let a = array![1, 2, 3];

        // Broadcast along rows (axis 0)
        let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 2);
        assert_eq!(b[[1, 0]], 1);
        assert_eq!(b[[1, 2]], 3);

        // Broadcast along columns (axis 1)
        let c = broadcast_1d_to_2d(a.view(), 2, 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[0, 1]], 1);
        assert_eq!(c[[1, 0]], 2);
        assert_eq!(c[[2, 1]], 3);
    }

    #[test]
    fn test_broadcast_apply() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = array![10, 20, 30];

        let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 11);
        assert_eq!(result[[0, 1]], 22);
        assert_eq!(result[[0, 2]], 33);
        assert_eq!(result[[1, 0]], 14);
        assert_eq!(result[[1, 1]], 25);
        assert_eq!(result[[1, 2]], 36);

        let result = broadcast_apply(a.view(), b.view(), |x, y| x * y).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 10);
        assert_eq!(result[[0, 1]], 40);
        assert_eq!(result[[0, 2]], 90);
        assert_eq!(result[[1, 0]], 40);
        assert_eq!(result[[1, 1]], 100);
        assert_eq!(result[[1, 2]], 180);
    }
}