scirs2_core/ndarray_ext/mod.rs
1//! Extended ndarray operations for scientific computing
2//!
3//! This module provides additional functionality for ndarray to support
4//! the advanced array operations necessary for a complete SciPy port.
5//! It implements core `NumPy`-like features that are not available in the
6//! base ndarray crate.
7
8/// Re-export essential ndarray types for convenience
9pub use ndarray::{
10 Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Dim, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6,
11 IxDyn, OwnedRepr, ShapeBuilder, SliceInfo, ViewRepr,
12};
13
14/// Advanced indexing operations (`NumPy`-like boolean masking, fancy indexing, etc.)
15pub mod indexing;
16
17/// Statistical functions for ndarray arrays (mean, median, variance, correlation, etc.)
18pub mod stats;
19
20/// Matrix operations (eye, diag, kron, etc.)
21pub mod matrix;
22
23/// Array manipulation operations (flip, roll, tile, repeat, etc.)
24pub mod manipulation;
25
26/// Reshape a 2D array to a new shape without copying data when possible
27///
28/// # Arguments
29///
30/// * `array` - The input array
31/// * `shape` - The new shape (rows, cols), which must be compatible with the original shape
32///
33/// # Returns
34///
35/// A reshaped array view of the input array if possible, or a new array otherwise
36///
37/// # Examples
38///
39/// ```
40/// use ndarray::array;
41/// use scirs2_core::ndarray_ext::reshape_2d;
42///
43/// let a = array![[1, 2], [3, 4]];
44/// let b = reshape_2d(a.view(), (4, 1)).unwrap();
45/// assert_eq!(b.shape(), &[4, 1]);
46/// assert_eq!(b[[0, 0]], 1);
47/// assert_eq!(b[[3, 0]], 4);
48/// ```
49#[allow(dead_code)]
50pub fn reshape_2d<T>(
51 array: ArrayView<T, Ix2>,
52 shape: (usize, usize),
53) -> Result<Array<T, Ix2>, &'static str>
54where
55 T: Clone + Default,
56{
57 let (rows, cols) = shape;
58 let total_elements = rows * cols;
59
60 // Check if the new shape is compatible with the original shape
61 if total_elements != array.len() {
62 return Err("New shape dimensions must match the total number of elements");
63 }
64
65 // Create a new array with the specified shape
66 let mut result = Array::<T, Ix2>::default(shape);
67
68 // Fill the result array with elements from the input array
69 let flat_iter = array.iter();
70 for (i, val) in flat_iter.enumerate() {
71 let r = i / cols;
72 let c = i % cols;
73 result[[r, c]] = val.clone();
74 }
75
76 Ok(result)
77}
78
79/// Stack 2D arrays along a given axis
80///
81/// # Arguments
82///
83/// * `arrays` - A slice of 2D arrays to stack
84/// * `axis` - The axis along which to stack (0 for rows, 1 for columns)
85///
86/// # Returns
87///
88/// A new array containing the stacked arrays
89///
90/// # Examples
91///
92/// ```
93/// use ndarray::{array, Axis};
94/// use scirs2_core::ndarray_ext::stack_2d;
95///
96/// let a = array![[1, 2], [3, 4]];
97/// let b = array![[5, 6], [7, 8]];
98/// let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
99/// assert_eq!(c.shape(), &[4, 2]);
100/// ```
101#[allow(dead_code)]
102pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
103where
104 T: Clone + Default,
105{
106 if arrays.is_empty() {
107 return Err("No _arrays provided for stacking");
108 }
109
110 // Validate that all _arrays have the same shape
111 let firstshape = arrays[0].shape();
112 for array in arrays.iter().skip(1) {
113 if array.shape() != firstshape {
114 return Err("All _arrays must have the same shape for stacking");
115 }
116 }
117
118 let (rows, cols) = (firstshape[0], firstshape[1]);
119
120 // Calculate the new shape
121 let (new_rows, new_cols) = match axis {
122 0 => (rows * arrays.len(), cols), // Stack vertically
123 1 => (rows, cols * arrays.len()), // Stack horizontally
124 _ => return Err("Axis must be 0 or 1 for 2D _arrays"),
125 };
126
127 // Create a new array to hold the stacked result
128 let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
129
130 // Copy data from the input _arrays to the result
131 match axis {
132 0 => {
133 // Stack vertically (along rows)
134 for (array_idx, array) in arrays.iter().enumerate() {
135 let start_row = array_idx * rows;
136 for r in 0..rows {
137 for c in 0..cols {
138 result[[start_row + r, c]] = array[[r, c]].clone();
139 }
140 }
141 }
142 }
143 1 => {
144 // Stack horizontally (along columns)
145 for (array_idx, array) in arrays.iter().enumerate() {
146 let start_col = array_idx * cols;
147 for r in 0..rows {
148 for c in 0..cols {
149 result[[r, start_col + c]] = array[[r, c]].clone();
150 }
151 }
152 }
153 }
154 _ => unreachable!(),
155 }
156
157 Ok(result)
158}
159
160/// Swap axes (transpose) of a 2D array
161///
162/// # Arguments
163///
164/// * `array` - The input 2D array
165///
166/// # Returns
167///
168/// A view of the input array with the axes swapped
169///
170/// # Examples
171///
172/// ```
173/// use ndarray::array;
174/// use scirs2_core::ndarray_ext::transpose_2d;
175///
176/// let a = array![[1, 2, 3], [4, 5, 6]];
177/// let b = transpose_2d(a.view());
178/// assert_eq!(b.shape(), &[3, 2]);
179/// assert_eq!(b[[0, 0]], 1);
180/// assert_eq!(b[[0, 1]], 4);
181/// ```
182#[allow(dead_code)]
183pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
184where
185 T: Clone,
186{
187 array.t().to_owned()
188}
189
190/// Split a 2D array into multiple sub-arrays along a given axis
191///
192/// # Arguments
193///
194/// * `array` - The input 2D array to split
195/// * `indices` - Indices where the array should be split
196/// * `axis` - The axis along which to split (0 for rows, 1 for columns)
197///
198/// # Returns
199///
200/// A vector of arrays resulting from the split
201///
202/// # Examples
203///
204/// ```
205/// use ndarray::array;
206/// use scirs2_core::ndarray_ext::split_2d;
207///
208/// let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
209/// let result = split_2d(a.view(), &[2], 1).unwrap();
210/// assert_eq!(result.len(), 2);
211/// assert_eq!(result[0].shape(), &[2, 2]);
212/// assert_eq!(result[1].shape(), &[2, 2]);
213/// ```
214#[allow(dead_code)]
215pub fn split_2d<T>(
216 array: ArrayView<T, Ix2>,
217 indices: &[usize],
218 axis: usize,
219) -> Result<Vec<Array<T, Ix2>>, &'static str>
220where
221 T: Clone + Default,
222{
223 if indices.is_empty() {
224 return Ok(vec![array.to_owned()]);
225 }
226
227 let (rows, cols) = (array.shape()[0], array.shape()[1]);
228 let axis_len = if axis == 0 { rows } else { cols };
229
230 // Validate indices
231 for &idx in indices {
232 if idx >= axis_len {
233 return Err("Split index out of bounds");
234 }
235 }
236
237 // Sort indices to ensure they're in ascending order
238 let mut sorted_indices = indices.to_vec();
239 sorted_indices.sort_unstable();
240
241 // Calculate the sub-array boundaries
242 let mut starts = vec![0];
243 starts.extend_from_slice(&sorted_indices);
244
245 let mut ends = sorted_indices.clone();
246 ends.push(axis_len);
247
248 // Create the split sub-arrays
249 let mut result = Vec::with_capacity(starts.len());
250
251 match axis {
252 0 => {
253 // Split along rows
254 for (&start, &end) in starts.iter().zip(ends.iter()) {
255 let sub_rows = end - start;
256 let mut sub_array = Array::<T, Ix2>::default((sub_rows, cols));
257
258 for r in 0..sub_rows {
259 for c in 0..cols {
260 sub_array[[r, c]] = array[[start + r, c]].clone();
261 }
262 }
263
264 result.push(sub_array);
265 }
266 }
267 1 => {
268 // Split along columns
269 for (&start, &end) in starts.iter().zip(ends.iter()) {
270 let sub_cols = end - start;
271 let mut sub_array = Array::<T, Ix2>::default((rows, sub_cols));
272
273 for r in 0..rows {
274 for c in 0..sub_cols {
275 sub_array[[r, c]] = array[[r, start + c]].clone();
276 }
277 }
278
279 result.push(sub_array);
280 }
281 }
282 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
283 }
284
285 Ok(result)
286}
287
288/// Take elements from a 2D array along a given axis using indices from another array
289///
290/// # Arguments
291///
292/// * `array` - The input 2D array
293/// * `indices` - Array of indices to take
294/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
295///
296/// # Returns
297///
298/// An array of values at the specified indices along the given axis
299///
300/// # Examples
301///
302/// ```
303/// use ndarray::array;
304/// use scirs2_core::ndarray_ext::take_2d;
305///
306/// let a = array![[1, 2, 3], [4, 5, 6]];
307/// let indices = array![0, 2];
308/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
309/// assert_eq!(result.shape(), &[2, 2]);
310/// assert_eq!(result[[0, 0]], 1);
311/// assert_eq!(result[[0, 1]], 3);
312/// ```
313#[allow(dead_code)]
314pub fn take_2d<T>(
315 array: ArrayView<T, Ix2>,
316 indices: ArrayView<usize, Ix1>,
317 axis: usize,
318) -> Result<Array<T, Ix2>, &'static str>
319where
320 T: Clone + Default,
321{
322 let (rows, cols) = (array.shape()[0], array.shape()[1]);
323 let axis_len = if axis == 0 { rows } else { cols };
324
325 // Check that indices are within bounds
326 for &idx in indices.iter() {
327 if idx >= axis_len {
328 return Err("Index out of bounds");
329 }
330 }
331
332 // Create the result array with the appropriate shape
333 let (result_rows, result_cols) = match axis {
334 0 => (indices.len(), cols),
335 1 => (rows, indices.len()),
336 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
337 };
338
339 let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
340
341 // Fill the result array
342 match axis {
343 0 => {
344 // Take along rows
345 for (i, &idx) in indices.iter().enumerate() {
346 for j in 0..cols {
347 result[[i, j]] = array[[idx, j]].clone();
348 }
349 }
350 }
351 1 => {
352 // Take along columns
353 for i in 0..rows {
354 for (j, &idx) in indices.iter().enumerate() {
355 result[[i, j]] = array[[i, idx]].clone();
356 }
357 }
358 }
359 _ => unreachable!(),
360 }
361
362 Ok(result)
363}
364
365/// Filter an array using a boolean mask
366///
367/// # Arguments
368///
369/// * `array` - The input array
370/// * `mask` - Boolean mask of the same shape as the array
371///
372/// # Returns
373///
374/// A 1D array containing the elements where the mask is true
375///
376/// # Examples
377///
378/// ```
379/// use ndarray::array;
380/// use scirs2_core::ndarray_ext::mask_select;
381///
382/// let a = array![[1, 2, 3], [4, 5, 6]];
383/// let mask = array![[true, false, true], [false, true, false]];
384/// let result = mask_select(a.view(), mask.view()).unwrap();
385/// assert_eq!(result.shape(), &[3]);
386/// assert_eq!(result[0], 1);
387/// assert_eq!(result[1], 3);
388/// assert_eq!(result[2], 5);
389/// ```
390#[allow(dead_code)]
391pub fn mask_select<T>(
392 array: ArrayView<T, Ix2>,
393 mask: ArrayView<bool, Ix2>,
394) -> Result<Array<T, Ix1>, &'static str>
395where
396 T: Clone + Default,
397{
398 // Check that the mask has the same shape as the array
399 if array.shape() != mask.shape() {
400 return Err("Mask shape must match array shape");
401 }
402
403 // Count the number of true values in the mask
404 let true_count = mask.iter().filter(|&&x| x).count();
405
406 // Create the result array
407 let mut result = Array::<T, Ix1>::default(true_count);
408
409 // Fill the result array with elements where the mask is true
410 let mut idx = 0;
411 for (val, &m) in array.iter().zip(mask.iter()) {
412 if m {
413 result[idx] = val.clone();
414 idx += 1;
415 }
416 }
417
418 Ok(result)
419}
420
421/// Index a 2D array with a list of index arrays (fancy indexing)
422///
423/// # Arguments
424///
425/// * `array` - The input 2D array
426/// * `row_indices` - Array of row indices
427/// * `col_indices` - Array of column indices
428///
429/// # Returns
430///
431/// A 1D array containing the elements at the specified indices
432///
433/// # Examples
434///
435/// ```
436/// use ndarray::array;
437/// use scirs2_core::ndarray_ext::fancy_index_2d;
438///
439/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
440/// let row_indices = array![0, 2];
441/// let col_indices = array![0, 1];
442/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
443/// assert_eq!(result.shape(), &[2]);
444/// assert_eq!(result[0], 1);
445/// assert_eq!(result[1], 8);
446/// ```
447#[allow(dead_code)]
448pub fn fancy_index_2d<T>(
449 array: ArrayView<T, Ix2>,
450 row_indices: ArrayView<usize, Ix1>,
451 col_indices: ArrayView<usize, Ix1>,
452) -> Result<Array<T, Ix1>, &'static str>
453where
454 T: Clone + Default,
455{
456 // Check that all index arrays have the same length
457 let result_size = row_indices.len();
458 if col_indices.len() != result_size {
459 return Err("Row and column index arrays must have the same length");
460 }
461
462 let (rows, cols) = (array.shape()[0], array.shape()[1]);
463
464 // Check that _indices are within bounds
465 for &idx in row_indices.iter() {
466 if idx >= rows {
467 return Err("Row index out of bounds");
468 }
469 }
470
471 for &idx in col_indices.iter() {
472 if idx >= cols {
473 return Err("Column index out of bounds");
474 }
475 }
476
477 // Create the result array
478 let mut result = Array::<T, Ix1>::default(result_size);
479
480 // Fill the result array
481 for i in 0..result_size {
482 let row = row_indices[i];
483 let col = col_indices[i];
484 result[i] = array[[row, col]].clone();
485 }
486
487 Ok(result)
488}
489
490/// Select elements from an array where a condition is true
491///
492/// # Arguments
493///
494/// * `array` - The input array
495/// * `condition` - A function that takes a reference to an element and returns a bool
496///
497/// # Returns
498///
499/// A 1D array containing the elements where the condition is true
500///
501/// # Examples
502///
503/// ```
504/// use ndarray::array;
505/// use scirs2_core::ndarray_ext::where_condition;
506///
507/// let a = array![[1, 2, 3], [4, 5, 6]];
508/// let result = where_condition(a.view(), |&x| x > 3).unwrap();
509/// assert_eq!(result.shape(), &[3]);
510/// assert_eq!(result[0], 4);
511/// assert_eq!(result[1], 5);
512/// assert_eq!(result[2], 6);
513/// ```
514#[allow(dead_code)]
515pub fn where_condition<T, F>(
516 array: ArrayView<T, Ix2>,
517 condition: F,
518) -> Result<Array<T, Ix1>, &'static str>
519where
520 T: Clone + Default,
521 F: Fn(&T) -> bool,
522{
523 // Build a boolean mask array based on the condition
524 let mask = array.map(condition);
525
526 // Use the mask_select function to select elements
527 mask_select(array, mask.view())
528}
529
/// Check if two shapes are broadcast compatible
///
/// Shapes are compared from their trailing dimension backwards. A pair of
/// dimensions is compatible when both are equal or when at least one of them
/// is 1; leading dimensions that exist in only one shape are always
/// compatible.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// `true` if the shapes are broadcast compatible, `false` otherwise
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::is_broadcast_compatible;
///
/// assert!(is_broadcast_compatible(&[2, 3], &[3]));
/// // Leading dimensions 5 and 3 differ and neither is 1
/// assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
/// assert!(!is_broadcast_compatible(&[2, 3], &[4]));
/// ```
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    // Zipping the reversed shapes aligns trailing dimensions and stops at the
    // shorter shape; the unmatched leading dimensions pair with an implicit 1
    // and therefore never cause a conflict.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
580
/// Calculate the broadcasted shape from two input shapes
///
/// Shapes are aligned on their trailing dimensions; the shorter shape is
/// treated as if it were padded with leading 1s. For each aligned pair, equal
/// dimensions are kept, a dimension of 1 stretches to the other dimension
/// (including to 0), and any other combination is incompatible.
///
/// # Arguments
///
/// * `shape1` - First shape as a slice
/// * `shape2` - Second shape as a slice
///
/// # Returns
///
/// The broadcasted shape as a `Vec<usize>`, or `None` if the shapes are incompatible
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray_ext::broadcastshape;
///
/// assert_eq!(broadcastshape(&[2, 1], &[1, 3]), Some(vec![2, 3]));
/// assert_eq!(broadcastshape(&[2, 3], &[4]), None);
/// ```
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let (longer, shorter) = if shape1.len() >= shape2.len() {
        (shape1, shape2)
    } else {
        (shape2, shape1)
    };

    // Leading dimensions present only in the longer shape pass through
    // unchanged, so start from a copy of it and reconcile the aligned tail.
    let offset = longer.len() - shorter.len();
    let mut result = longer.to_vec();

    for (out, &dim) in result[offset..].iter_mut().zip(shorter) {
        if *out == dim || dim == 1 {
            // Equal, or the shorter shape stretches to `*out`.
            continue;
        }
        if *out == 1 {
            // The longer shape stretches to `dim`. Taking the maximum here
            // would be wrong when `dim` is 0.
            *out = dim;
        } else {
            return None;
        }
    }

    Some(result)
}
622
623/// Broadcast a 1D array to a 2D shape by repeating it along the specified axis
624///
625/// # Arguments
626///
627/// * `array` - The input 1D array
628/// * `repeats` - Number of times to repeat the array
629/// * `axis` - Axis along which to repeat (0 for rows, 1 for columns)
630///
631/// # Returns
632///
633/// A 2D array with the input array repeated along the specified axis
634///
635/// # Examples
636///
637/// ```
638/// use ndarray::array;
639/// use scirs2_core::ndarray_ext::broadcast_1d_to_2d;
640///
641/// let a = array![1, 2, 3];
642/// let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
643/// assert_eq!(b.shape(), &[2, 3]);
644/// assert_eq!(b[[0, 0]], 1);
645/// assert_eq!(b[[1, 0]], 1);
646/// ```
647#[allow(dead_code)]
648pub fn broadcast_1d_to_2d<T>(
649 array: ArrayView<T, Ix1>,
650 repeats: usize,
651 axis: usize,
652) -> Result<Array<T, Ix2>, &'static str>
653where
654 T: Clone + Default,
655{
656 let len = array.len();
657
658 // Create the result array with the appropriate shape
659 let (rows, cols) = match axis {
660 0 => (repeats, len), // Broadcast along rows
661 1 => (len, repeats), // Broadcast along columns
662 _ => return Err("Axis must be 0 or 1"),
663 };
664
665 let mut result = Array::<T, Ix2>::default((rows, cols));
666
667 // Fill the result array
668 match axis {
669 0 => {
670 // Broadcast along rows
671 for i in 0..repeats {
672 for j in 0..len {
673 result[[i, j]] = array[j].clone();
674 }
675 }
676 }
677 1 => {
678 // Broadcast along columns
679 for i in 0..len {
680 for j in 0..repeats {
681 result[[i, j]] = array[i].clone();
682 }
683 }
684 }
685 _ => unreachable!(),
686 }
687
688 Ok(result)
689}
690
691/// Apply an element-wise binary operation to two arrays with broadcasting
692///
693/// # Arguments
694///
695/// * `a` - First array (2D)
696/// * `b` - Second array (can be 1D or 2D)
697/// * `op` - Binary operation to apply to each pair of elements
698///
699/// # Returns
700///
701/// A 2D array containing the result of the operation applied element-wise
702///
703/// # Examples
704///
705/// ```
706/// use ndarray::array;
707/// use scirs2_core::ndarray_ext::broadcast_apply;
708///
709/// let a = array![[1, 2, 3], [4, 5, 6]];
710/// let b = array![10, 20, 30];
711/// let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
712/// assert_eq!(result.shape(), &[2, 3]);
713/// assert_eq!(result[[0, 0]], 11);
714/// assert_eq!(result[[1, 2]], 36);
715/// ```
716#[allow(dead_code)]
717pub fn broadcast_apply<T, R, F>(
718 a: ArrayView<T, Ix2>,
719 b: ArrayView<T, Ix1>,
720 op: F,
721) -> Result<Array<R, Ix2>, &'static str>
722where
723 T: Clone + Default,
724 R: Clone + Default,
725 F: Fn(&T, &T) -> R,
726{
727 let (a_rows, a_cols) = (a.shape()[0], a.shape()[1]);
728 let b_len = b.len();
729
730 // Check that the arrays are broadcast compatible
731 if a_cols != b_len {
732 return Err("Arrays are not broadcast compatible");
733 }
734
735 // Create the result array
736 let mut result = Array::<R, Ix2>::default((a_rows, a_cols));
737
738 // Apply the operation element-wise
739 for i in 0..a_rows {
740 for j in 0..a_cols {
741 result[[i, j]] = op(&a[[i, j]], &b[j]);
742 }
743 }
744
745 Ok(result)
746}
747
/// Unit tests for the 2D helper functions in this module.
///
/// Results are compared against whole expected arrays rather than individual
/// cells, and each helper's error paths are exercised alongside the happy path.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_reshape_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = reshape_2d(a.view(), (4, 1)).unwrap();
        assert_eq!(b, array![[1], [2], [3], [4]]);

        // A shape with a different element count is rejected
        assert!(reshape_2d(a.view(), (3, 1)).is_err());
    }

    #[test]
    fn test_stack_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];

        // Stack vertically (along axis 0)
        let c = stack_2d(&[a.view(), b.view()], 0).unwrap();
        assert_eq!(c, array![[1, 2], [3, 4], [5, 6], [7, 8]]);

        // Stack horizontally (along axis 1)
        let d = stack_2d(&[a.view(), b.view()], 1).unwrap();
        assert_eq!(d, array![[1, 2, 5, 6], [3, 4, 7, 8]]);
    }

    #[test]
    fn test_stack_2d_errors() {
        let a = array![[1, 2], [3, 4]];
        let wide = array![[1, 2, 3]];

        // Nothing to stack
        assert!(stack_2d::<i32>(&[], 0).is_err());
        // Mismatched shapes
        assert!(stack_2d(&[a.view(), wide.view()], 0).is_err());
        // Unsupported axis
        assert!(stack_2d(&[a.view(), a.view()], 2).is_err());
    }

    #[test]
    fn test_transpose_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = transpose_2d(a.view());
        assert_eq!(b, array![[1, 4], [2, 5], [3, 6]]);
    }

    #[test]
    fn test_split_2d() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Split along columns at index 2
        let result = split_2d(a.view(), &[2], 1).unwrap();
        assert_eq!(
            result,
            vec![array![[1, 2], [5, 6]], array![[3, 4], [7, 8]]]
        );

        // Split along rows at index 1
        let result = split_2d(a.view(), &[1], 0).unwrap();
        assert_eq!(result, vec![array![[1, 2, 3, 4]], array![[5, 6, 7, 8]]]);

        // No split points: the whole array comes back as a single piece
        let result = split_2d(a.view(), &[], 1).unwrap();
        assert_eq!(result, vec![a.clone()]);
    }

    #[test]
    fn test_split_2d_errors() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

        // Split index equal to the axis length is out of bounds
        assert!(split_2d(a.view(), &[4], 1).is_err());
        // Unsupported axis
        assert!(split_2d(a.view(), &[1], 2).is_err());
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        // Take along columns
        let col_picks = array![0, 2];
        let result = take_2d(a.view(), col_picks.view(), 1).unwrap();
        assert_eq!(result, array![[1, 3], [4, 6]]);

        // Take along rows, in reversed order
        let row_picks = array![1, 0];
        let result = take_2d(a.view(), row_picks.view(), 0).unwrap();
        assert_eq!(result, array![[4, 5, 6], [1, 2, 3]]);
    }

    #[test]
    fn test_take_2d_errors() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let in_range = array![0, 2];
        let out_of_range = array![3];

        assert!(take_2d(a.view(), out_of_range.view(), 1).is_err());
        assert!(take_2d(a.view(), in_range.view(), 2).is_err());
    }

    #[test]
    fn test_mask_select() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let result = mask_select(a.view(), mask.view()).unwrap();
        assert_eq!(result, array![1, 3, 5]);

        // A mask with a different shape is rejected
        let bad_mask = array![[true, false]];
        assert!(mask_select(a.view(), bad_mask.view()).is_err());
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(result, array![1, 8]);
    }

    #[test]
    fn test_fancy_index_2d_errors() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];
        let short = array![0];
        let too_far = array![0, 3];

        // Index arrays of different lengths
        assert!(fancy_index_2d(a.view(), row_indices.view(), short.view()).is_err());
        // Row index out of bounds
        assert!(fancy_index_2d(a.view(), too_far.view(), col_indices.view()).is_err());
        // Column index out of bounds
        assert!(fancy_index_2d(a.view(), row_indices.view(), too_far.view()).is_err());
    }

    #[test]
    fn test_where_condition() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let result = where_condition(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result, array![4, 5, 6]);

        // A predicate nothing satisfies yields an empty array
        let none = where_condition(a.view(), |&x| x > 100).unwrap();
        assert_eq!(none.len(), 0);
    }

    #[test]
    fn test_is_broadcast_compatible() {
        assert!(is_broadcast_compatible(&[2, 3], &[3]));
        assert!(is_broadcast_compatible(&[4, 1], &[1, 5]));
        assert!(is_broadcast_compatible(&[7], &[3, 1, 7]));
        assert!(!is_broadcast_compatible(&[5, 1, 4], &[3, 1, 1]));
        assert!(!is_broadcast_compatible(&[2, 3], &[4]));
    }

    #[test]
    fn test_broadcastshape() {
        assert_eq!(broadcastshape(&[2, 3], &[3]), Some(vec![2, 3]));
        assert_eq!(broadcastshape(&[2, 1], &[1, 3]), Some(vec![2, 3]));
        assert_eq!(broadcastshape(&[7], &[3, 1, 7]), Some(vec![3, 1, 7]));
        assert_eq!(broadcastshape(&[2, 3], &[4]), None);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let a = array![1, 2, 3];

        // Broadcast along rows (axis 0)
        let b = broadcast_1d_to_2d(a.view(), 2, 0).unwrap();
        assert_eq!(b, array![[1, 2, 3], [1, 2, 3]]);

        // Broadcast along columns (axis 1)
        let c = broadcast_1d_to_2d(a.view(), 2, 1).unwrap();
        assert_eq!(c, array![[1, 1], [2, 2], [3, 3]]);

        // Unsupported axis
        assert!(broadcast_1d_to_2d(a.view(), 2, 2).is_err());
    }

    #[test]
    fn test_broadcast_apply() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = array![10, 20, 30];

        let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).unwrap();
        assert_eq!(result, array![[11, 22, 33], [14, 25, 36]]);

        let result = broadcast_apply(a.view(), b.view(), |x, y| x * y).unwrap();
        assert_eq!(result, array![[10, 40, 90], [40, 100, 180]]);

        // `b` must have one element per column of `a`
        let short = array![10, 20];
        assert!(broadcast_apply(a.view(), short.view(), |x, y| x + y).is_err());
    }
}
918}