scirs2_core/ndarray_ext/indexing.rs
1//! Advanced indexing operations for ndarray
2//!
3//! This module provides enhanced indexing capabilities including boolean
4//! masking, fancy indexing, and advanced slicing operations similar to
5//! `NumPy`'s advanced indexing functionality.
6
7use ndarray::{Array, ArrayView, Ix1, Ix2};
8
/// Result type for 2D coordinate lookups: on success, a `(row_indices, col_indices)` pair of equally long index arrays
10pub type IndicesResult = Result<(Array<usize, Ix1>, Array<usize, Ix1>), &'static str>;
11
12/// Take elements from a 2D array along a given axis using indices from another array
13///
14/// # Arguments
15///
16/// * `array` - The input 2D array
17/// * `indices` - Array of indices to take
18/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
19///
20/// # Returns
21///
22/// A 2D array of values at the specified indices along the given axis
23///
24/// # Examples
25///
26/// ```
27/// use ndarray::array;
28/// use scirs2_core::ndarray_ext::indexing::take_2d;
29///
30/// let a = array![[1, 2, 3], [4, 5, 6]];
31/// let indices = array![0, 2];
32/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
33/// assert_eq!(result.shape(), &[2, 2]);
34/// assert_eq!(result[[0, 0]], 1);
35/// assert_eq!(result[[0, 1]], 3);
36/// ```
37#[allow(dead_code)]
38pub fn take_2d<T>(
39 array: ArrayView<T, Ix2>,
40 indices: ArrayView<usize, Ix1>,
41 axis: usize,
42) -> Result<Array<T, Ix2>, &'static str>
43where
44 T: Clone + Default,
45{
46 let (rows, cols) = (array.shape()[0], array.shape()[1]);
47 let axis_len = if axis == 0 { rows } else { cols };
48
49 // Check that indices are within bounds
50 for &idx in indices.iter() {
51 if idx >= axis_len {
52 return Err("Index out of bounds");
53 }
54 }
55
56 // Create the result array with the appropriate shape
57 let (result_rows, result_cols) = match axis {
58 0 => (indices.len(), cols),
59 1 => (rows, indices.len()),
60 _ => return Err("Axis must be 0 or 1 for 2D arrays"),
61 };
62
63 let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
64
65 // Fill the result array
66 match axis {
67 0 => {
68 // Take along rows
69 for (i, &idx) in indices.iter().enumerate() {
70 for j in 0..cols {
71 result[[i, j]] = array[[idx, j]].clone();
72 }
73 }
74 }
75 1 => {
76 // Take along columns
77 for i in 0..rows {
78 for (j, &idx) in indices.iter().enumerate() {
79 result[[i, j]] = array[[i, idx]].clone();
80 }
81 }
82 }
83 _ => unreachable!(),
84 }
85
86 Ok(result)
87}
88
89/// Boolean mask indexing for 2D arrays
90///
91/// # Arguments
92///
93/// * `array` - The input 2D array
94/// * `mask` - Boolean mask of the same shape as the array
95///
96/// # Returns
97///
98/// A 1D array containing the elements where the mask is true
99///
100/// # Examples
101///
102/// ```
103/// use ndarray::array;
104/// use scirs2_core::ndarray_ext::indexing::boolean_mask_2d;
105///
106/// let a = array![[1, 2, 3], [4, 5, 6]];
107/// let mask = array![[true, false, true], [false, true, false]];
108/// let result = boolean_mask_2d(a.view(), mask.view()).unwrap();
109/// assert_eq!(result.len(), 3);
110/// assert_eq!(result[0], 1);
111/// assert_eq!(result[1], 3);
112/// assert_eq!(result[2], 5);
113/// ```
114#[allow(dead_code)]
115pub fn boolean_mask_2d<T>(
116 array: ArrayView<T, Ix2>,
117 mask: ArrayView<bool, Ix2>,
118) -> Result<Array<T, Ix1>, &'static str>
119where
120 T: Clone + Default,
121{
122 // Check that the mask has the same shape as the array
123 if array.shape() != mask.shape() {
124 return Err("Mask shape must match array shape");
125 }
126
127 // Count the number of true values in the mask
128 let true_count = mask.iter().filter(|&&x| x).count();
129
130 // Create the result array
131 let mut result = Array::<T, Ix1>::default(true_count);
132
133 // Fill the result array with elements where the mask is true
134 let mut idx = 0;
135 for (val, &m) in array.iter().zip(mask.iter()) {
136 if m {
137 result[idx] = val.clone();
138 idx += 1;
139 }
140 }
141
142 Ok(result)
143}
144
145/// Boolean mask indexing for 1D arrays
146///
147/// # Arguments
148///
149/// * `array` - The input 1D array
150/// * `mask` - Boolean mask of the same shape as the array
151///
152/// # Returns
153///
154/// A 1D array containing the elements where the mask is true
155///
156/// # Examples
157///
158/// ```
159/// use ndarray::array;
160/// use scirs2_core::ndarray_ext::indexing::boolean_mask_1d;
161///
162/// let a = array![1, 2, 3, 4, 5];
163/// let mask = array![true, false, true, false, true];
164/// let result = boolean_mask_1d(a.view(), mask.view()).unwrap();
165/// assert_eq!(result.len(), 3);
166/// assert_eq!(result[0], 1);
167/// assert_eq!(result[1], 3);
168/// assert_eq!(result[2], 5);
169/// ```
170#[allow(dead_code)]
171pub fn boolean_mask_1d<T>(
172 array: ArrayView<T, Ix1>,
173 mask: ArrayView<bool, Ix1>,
174) -> Result<Array<T, Ix1>, &'static str>
175where
176 T: Clone + Default,
177{
178 // Check that the mask has the same shape as the array
179 if array.shape() != mask.shape() {
180 return Err("Mask shape must match array shape");
181 }
182
183 // Count the number of true values in the mask
184 let true_count = mask.iter().filter(|&&x| x).count();
185
186 // Create the result array
187 let mut result = Array::<T, Ix1>::default(true_count);
188
189 // Fill the result array with elements where the mask is true
190 let mut idx = 0;
191 for (val, &m) in array.iter().zip(mask.iter()) {
192 if m {
193 result[idx] = val.clone();
194 idx += 1;
195 }
196 }
197
198 Ok(result)
199}
200
201/// Indexed slicing for 1D arrays
202///
203/// # Arguments
204///
205/// * `array` - The input 1D array
206/// * `indices` - Array of indices to extract
207///
208/// # Returns
209///
210/// A 1D array containing the elements at the specified indices
211///
212/// # Examples
213///
214/// ```
215/// use ndarray::array;
216/// use scirs2_core::ndarray_ext::indexing::take_1d;
217///
218/// let a = array![10, 20, 30, 40, 50];
219/// let indices = array![0, 2, 4];
220/// let result = take_1d(a.view(), indices.view()).unwrap();
221/// assert_eq!(result.len(), 3);
222/// assert_eq!(result[0], 10);
223/// assert_eq!(result[1], 30);
224/// assert_eq!(result[2], 50);
225/// ```
226#[allow(dead_code)]
227pub fn take_1d<T>(
228 array: ArrayView<T, Ix1>,
229 indices: ArrayView<usize, Ix1>,
230) -> Result<Array<T, Ix1>, &'static str>
231where
232 T: Clone + Default,
233{
234 let len = array.len();
235
236 // Verify that indices are in bounds
237 for &idx in indices.iter() {
238 if idx >= len {
239 return Err("Index out of bounds");
240 }
241 }
242
243 // Create the result array
244 let mut result = Array::<T, Ix1>::default(indices.len());
245
246 // Extract the elements at the specified indices
247 for (i, &idx) in indices.iter().enumerate() {
248 result[i] = array[idx].clone();
249 }
250
251 Ok(result)
252}
253
254/// Fancy indexing for 2D arrays with pairs of index arrays
255///
256/// # Arguments
257///
258/// * `array` - The input 2D array
259/// * `row_indices` - Array of row indices
260/// * `col_indices` - Array of column indices
261///
262/// # Returns
263///
264/// A 1D array containing the elements at the specified indices
265///
266/// # Examples
267///
268/// ```
269/// use ndarray::array;
270/// use scirs2_core::ndarray_ext::indexing::fancy_index_2d;
271///
272/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
273/// let row_indices = array![0, 1, 2];
274/// let col_indices = array![0, 1, 2];
275/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
276/// assert_eq!(result.len(), 3);
277/// assert_eq!(result[0], 1);
278/// assert_eq!(result[1], 5);
279/// assert_eq!(result[2], 9);
280/// ```
281#[allow(dead_code)]
282pub fn fancy_index_2d<T>(
283 array: ArrayView<T, Ix2>,
284 row_indices: ArrayView<usize, Ix1>,
285 col_indices: ArrayView<usize, Ix1>,
286) -> Result<Array<T, Ix1>, &'static str>
287where
288 T: Clone + Default,
289{
290 // Check that all index arrays have the same length
291 let result_size = row_indices.len();
292 if col_indices.len() != result_size {
293 return Err("Row and column index arrays must have the same length");
294 }
295
296 let (rows, cols) = (array.shape()[0], array.shape()[1]);
297
298 // Check that _indices are within bounds
299 for &idx in row_indices.iter() {
300 if idx >= rows {
301 return Err("Row index out of bounds");
302 }
303 }
304
305 for &idx in col_indices.iter() {
306 if idx >= cols {
307 return Err("Column index out of bounds");
308 }
309 }
310
311 // Create the result array
312 let mut result = Array::<T, Ix1>::default(result_size);
313
314 // Fill the result array
315 for i in 0..result_size {
316 let row = row_indices[i];
317 let col = col_indices[i];
318 result[i] = array[[row, col]].clone();
319 }
320
321 Ok(result)
322}
323
324/// Extract a diagonal or a sub-diagonal from a 2D array
325///
326/// # Arguments
327///
328/// * `array` - The input 2D array
329/// * `offset` - Offset from the main diagonal (0 for main diagonal, positive for above, negative for below)
330///
331/// # Returns
332///
333/// A 1D array containing the diagonal elements
334///
335/// # Examples
336///
337/// ```
338/// use ndarray::array;
339/// use scirs2_core::ndarray_ext::indexing::diagonal;
340///
341/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
342///
343/// // Main diagonal
344/// let main_diag = diagonal(a.view(), 0).unwrap();
345/// assert_eq!(main_diag.len(), 3);
346/// assert_eq!(main_diag[0], 1);
347/// assert_eq!(main_diag[1], 5);
348/// assert_eq!(main_diag[2], 9);
349///
350/// // Upper diagonal
351/// let upper_diag = diagonal(a.view(), 1).unwrap();
352/// assert_eq!(upper_diag.len(), 2);
353/// assert_eq!(upper_diag[0], 2);
354/// assert_eq!(upper_diag[1], 6);
355///
356/// // Lower diagonal
357/// let lower_diag = diagonal(a.view(), -1).unwrap();
358/// assert_eq!(lower_diag.len(), 2);
359/// assert_eq!(lower_diag[0], 4);
360/// assert_eq!(lower_diag[1], 8);
361/// ```
362#[allow(dead_code)]
363pub fn diagonal<T>(array: ArrayView<T, Ix2>, offset: isize) -> Result<Array<T, Ix1>, &'static str>
364where
365 T: Clone + Default,
366{
367 let (rows, cols) = (array.shape()[0], array.shape()[1]);
368
369 // Calculate the length of the diagonal
370 let diag_len = if offset >= 0 {
371 std::cmp::min(rows, cols.saturating_sub(offset as usize))
372 } else {
373 std::cmp::min(cols, rows.saturating_sub((-offset) as usize))
374 };
375
376 if diag_len == 0 {
377 return Err("No diagonal elements for the given offset");
378 }
379
380 // Create the result _array
381 let mut result = Array::<T, Ix1>::default(diag_len);
382
383 // Extract the diagonal elements
384 for i in 0..diag_len {
385 let row = if offset < 0 {
386 i + (-offset) as usize
387 } else {
388 i
389 };
390
391 let col = if offset > 0 { i + offset as usize } else { i };
392
393 result[i] = array[[row, col]].clone();
394 }
395
396 Ok(result)
397}
398
399/// Where function - select elements based on a condition for 1D arrays
400///
401/// # Arguments
402///
403/// * `array` - The input 1D array
404/// * `condition` - A function that takes a reference to an element and returns a bool
405///
406/// # Returns
407///
408/// A 1D array containing the elements where the condition is true
409///
410/// # Examples
411///
412/// ```
413/// use ndarray::array;
414/// use scirs2_core::ndarray_ext::indexing::where_1d;
415///
416/// let a = array![1, 2, 3, 4, 5];
417/// let result = where_1d(a.view(), |&x| x > 3).unwrap();
418/// assert_eq!(result.len(), 2);
419/// assert_eq!(result[0], 4);
420/// assert_eq!(result[1], 5);
421/// ```
422#[allow(dead_code)]
423pub fn where_1d<T, F>(array: ArrayView<T, Ix1>, condition: F) -> Result<Array<T, Ix1>, &'static str>
424where
425 T: Clone + Default,
426 F: Fn(&T) -> bool,
427{
428 // Build a boolean mask _array based on the condition
429 let mask = array.map(condition);
430
431 // Use the boolean_mask_1d function to select elements
432 boolean_mask_1d(array, mask.view())
433}
434
435/// Where function - select elements based on a condition for 2D arrays
436///
437/// # Arguments
438///
439/// * `array` - The input 2D array
440/// * `condition` - A function that takes a reference to an element and returns a bool
441///
442/// # Returns
443///
444/// A 1D array containing the elements where the condition is true
445///
446/// # Examples
447///
448/// ```
449/// use ndarray::array;
450/// use scirs2_core::ndarray_ext::indexing::where_2d;
451///
452/// let a = array![[1, 2, 3], [4, 5, 6]];
453/// let result = where_2d(a.view(), |&x| x > 3).unwrap();
454/// assert_eq!(result.len(), 3);
455/// assert_eq!(result[0], 4);
456/// assert_eq!(result[1], 5);
457/// assert_eq!(result[2], 6);
458/// ```
459#[allow(dead_code)]
460pub fn where_2d<T, F>(array: ArrayView<T, Ix2>, condition: F) -> Result<Array<T, Ix1>, &'static str>
461where
462 T: Clone + Default,
463 F: Fn(&T) -> bool,
464{
465 // Build a boolean mask _array based on the condition
466 let mask = array.map(condition);
467
468 // Use the boolean_mask_2d function to select elements
469 boolean_mask_2d(array, mask.view())
470}
471
472/// Extract indices where a 1D array meets a condition
473///
474/// # Arguments
475///
476/// * `array` - The input 1D array
477/// * `condition` - A function that takes a reference to an element and returns a bool
478///
479/// # Returns
480///
481/// A 1D array of indices where the condition is true
482///
483/// # Examples
484///
485/// ```
486/// use ndarray::array;
487/// use scirs2_core::ndarray_ext::indexing::indices_where_1d;
488///
489/// let a = array![10, 20, 30, 40, 50];
490/// let result = indices_where_1d(a.view(), |&x| x > 30).unwrap();
491/// assert_eq!(result.len(), 2);
492/// assert_eq!(result[0], 3);
493/// assert_eq!(result[1], 4);
494/// ```
495#[allow(dead_code)]
496pub fn indices_where_1d<T, F>(
497 array: ArrayView<T, Ix1>,
498 condition: F,
499) -> Result<Array<usize, Ix1>, &'static str>
500where
501 T: Clone,
502 F: Fn(&T) -> bool,
503{
504 // Build a vector of indices where the condition is true
505 let mut indices = Vec::new();
506
507 for (i, val) in array.iter().enumerate() {
508 if condition(val) {
509 indices.push(i);
510 }
511 }
512
513 // Convert the vector to an ndarray Array
514 Ok(Array::from_vec(indices))
515}
516
517/// Extract indices where a 2D array meets a condition
518///
519/// # Arguments
520///
521/// * `array` - The input 2D array
522/// * `condition` - A function that takes a reference to an element and returns a bool
523///
524/// # Returns
525///
526/// A tuple of two 1D arrays (row_indices, col_indices) where the condition is true
527///
528/// # Examples
529///
530/// ```
531/// use ndarray::array;
532/// use scirs2_core::ndarray_ext::indexing::indices_where_2d;
533///
534/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
535/// let (rows, cols) = indices_where_2d(a.view(), |&x| x > 5).unwrap();
536/// assert_eq!(rows.len(), 4);
537/// assert_eq!(cols.len(), 4);
538/// // The indices correspond to elements: 6, 7, 8, 9
539/// ```
540#[allow(dead_code)]
541pub fn indices_where_2d<T, F>(array: ArrayView<T, Ix2>, condition: F) -> IndicesResult
542where
543 T: Clone,
544 F: Fn(&T) -> bool,
545{
546 let (rows, cols) = (array.shape()[0], array.shape()[1]);
547
548 // Build vectors of row and column indices where the condition is true
549 let mut row_indices = Vec::new();
550 let mut col_indices = Vec::new();
551
552 for r in 0..rows {
553 for c in 0..cols {
554 if condition(&array[[r, c]]) {
555 row_indices.push(r);
556 col_indices.push(c);
557 }
558 }
559 }
560
561 // Convert the vectors to ndarray Arrays
562 Ok((Array::from_vec(row_indices), Array::from_vec(col_indices)))
563}
564
565/// Return elements from a 2D array along an axis at specified indices
566///
567/// # Arguments
568///
569/// * `array` - The input 2D array
570/// * `indices` - Indices to take along the specified axis
571/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
572///
573/// # Returns
574///
575/// A 2D array with selected slices from the original array
576///
577/// # Examples
578///
579/// ```
580/// use ndarray::array;
581/// use scirs2_core::ndarray_ext::indexing::take_along_axis;
582///
583/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
584/// let indices = array![0, 2];
585///
586/// // Take rows 0 and 2
587/// let result = take_along_axis(a.view(), indices.view(), 0).unwrap();
588/// assert_eq!(result.shape(), &[2, 3]);
589/// assert_eq!(result[[0, 0]], 1);
590/// assert_eq!(result[[0, 1]], 2);
591/// assert_eq!(result[[0, 2]], 3);
592/// assert_eq!(result[[1, 0]], 7);
593/// assert_eq!(result[[1, 1]], 8);
594/// assert_eq!(result[[1, 2]], 9);
595/// ```
596#[allow(dead_code)]
597pub fn take_along_axis<T>(
598 array: ArrayView<T, Ix2>,
599 indices: ArrayView<usize, Ix1>,
600 axis: usize,
601) -> Result<Array<T, Ix2>, &'static str>
602where
603 T: Clone + Default,
604{
605 take_2d(array, indices, axis)
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// A 1D mask keeps exactly the flagged elements, in order.
    #[test]
    fn test_boolean_mask_1d() {
        let values = array![1, 2, 3, 4, 5];
        let mask = array![true, false, true, false, true];

        let kept = boolean_mask_1d(values.view(), mask.view()).unwrap();
        assert_eq!(kept, array![1, 3, 5]);
    }

    /// A 2D mask flattens the flagged elements row by row.
    #[test]
    fn test_boolean_mask_2d() {
        let values = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let kept = boolean_mask_2d(values.view(), mask.view()).unwrap();
        assert_eq!(kept, array![1, 3, 5]);
    }

    /// Gathering by index returns the addressed elements in index order.
    #[test]
    fn test_take_1d() {
        let values = array![10, 20, 30, 40, 50];
        let positions = array![0, 2, 4];

        let gathered = take_1d(values.view(), positions.view()).unwrap();
        assert_eq!(gathered, array![10, 30, 50]);
    }

    /// Taking along either axis yields the selected rows or columns.
    #[test]
    fn test_take_2d() {
        let grid = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let positions = array![0, 2];

        // Axis 0 selects rows 0 and 2.
        let by_rows = take_2d(grid.view(), positions.view(), 0).unwrap();
        assert_eq!(by_rows, array![[1, 2, 3], [7, 8, 9]]);

        // Axis 1 selects columns 0 and 2.
        let by_cols = take_2d(grid.view(), positions.view(), 1).unwrap();
        assert_eq!(by_cols, array![[1, 3], [4, 6], [7, 9]]);
    }

    /// Paired indices address individual elements: (0, 0) and (2, 1).
    #[test]
    fn test_fancy_index_2d() {
        let grid = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_positions = array![0, 2];
        let col_positions = array![0, 1];

        let picked =
            fancy_index_2d(grid.view(), row_positions.view(), col_positions.view()).unwrap();
        assert_eq!(picked, array![1, 8]);
    }

    /// Main, first upper and first lower diagonals of a 3x3 grid.
    #[test]
    fn test_diagonal() {
        let grid = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        assert_eq!(diagonal(grid.view(), 0).unwrap(), array![1, 5, 9]);
        assert_eq!(diagonal(grid.view(), 1).unwrap(), array![2, 6]);
        assert_eq!(diagonal(grid.view(), -1).unwrap(), array![4, 8]);
    }

    /// Predicate selection on a 1D array.
    #[test]
    fn test_where_1d() {
        let values = array![1, 2, 3, 4, 5];

        let matches = where_1d(values.view(), |&x| x > 3).unwrap();
        assert_eq!(matches, array![4, 5]);
    }

    /// Predicate selection on a 2D array flattens the matches.
    #[test]
    fn test_where_2d() {
        let grid = array![[1, 2, 3], [4, 5, 6]];

        let matches = where_2d(grid.view(), |&x| x > 3).unwrap();
        assert_eq!(matches, array![4, 5, 6]);
    }

    /// Positions of the matching elements in a 1D array.
    #[test]
    fn test_indices_where_1d() {
        let values = array![10, 20, 30, 40, 50];

        let positions = indices_where_1d(values.view(), |&x| x > 30).unwrap();
        assert_eq!(positions, array![3, 4]);
    }

    /// Coordinates of the matching elements in a 2D array.
    #[test]
    fn test_indices_where_2d() {
        let grid = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        let (row_positions, col_positions) = indices_where_2d(grid.view(), |&x| x > 5).unwrap();
        assert_eq!((row_positions.len(), col_positions.len()), (4, 4));

        // Every reported coordinate must address an element that matches.
        assert!(row_positions
            .iter()
            .zip(col_positions.iter())
            .all(|(&r, &c)| grid[[r, c]] > 5));
    }

    /// The alias behaves like `take_2d` along axis 0.
    #[test]
    fn test_take_along_axis() {
        let grid = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let positions = array![0, 2];

        let by_rows = take_along_axis(grid.view(), positions.view(), 0).unwrap();
        assert_eq!(by_rows, array![[1, 2, 3], [7, 8, 9]]);
    }
}