scirs2_core/ndarray_ext/
indexing.rs

1//! Advanced indexing operations for ndarray
2//!
3//! This module provides enhanced indexing capabilities including boolean
4//! masking, fancy indexing, and advanced slicing operations similar to
5//! `NumPy`'s advanced indexing functionality.
6
7use ndarray::{Array, ArrayView, Ix1, Ix2};
8
/// Result type for 2D index queries: a `(row_indices, col_indices)` pair of 1D arrays, or a static error message
10pub type IndicesResult = Result<(Array<usize, Ix1>, Array<usize, Ix1>), &'static str>;
11
12/// Take elements from a 2D array along a given axis using indices from another array
13///
14/// # Arguments
15///
16/// * `array` - The input 2D array
17/// * `indices` - Array of indices to take
18/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
19///
20/// # Returns
21///
22/// A 2D array of values at the specified indices along the given axis
23///
24/// # Examples
25///
26/// ```
27/// use ndarray::array;
28/// use scirs2_core::ndarray_ext::indexing::take_2d;
29///
30/// let a = array![[1, 2, 3], [4, 5, 6]];
31/// let indices = array![0, 2];
32/// let result = take_2d(a.view(), indices.view(), 1).unwrap();
33/// assert_eq!(result.shape(), &[2, 2]);
34/// assert_eq!(result[[0, 0]], 1);
35/// assert_eq!(result[[0, 1]], 3);
36/// ```
37#[allow(dead_code)]
38pub fn take_2d<T>(
39    array: ArrayView<T, Ix2>,
40    indices: ArrayView<usize, Ix1>,
41    axis: usize,
42) -> Result<Array<T, Ix2>, &'static str>
43where
44    T: Clone + Default,
45{
46    let (rows, cols) = (array.shape()[0], array.shape()[1]);
47    let axis_len = if axis == 0 { rows } else { cols };
48
49    // Check that indices are within bounds
50    for &idx in indices.iter() {
51        if idx >= axis_len {
52            return Err("Index out of bounds");
53        }
54    }
55
56    // Create the result array with the appropriate shape
57    let (result_rows, result_cols) = match axis {
58        0 => (indices.len(), cols),
59        1 => (rows, indices.len()),
60        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
61    };
62
63    let mut result = Array::<T, Ix2>::default((result_rows, result_cols));
64
65    // Fill the result array
66    match axis {
67        0 => {
68            // Take along rows
69            for (i, &idx) in indices.iter().enumerate() {
70                for j in 0..cols {
71                    result[[i, j]] = array[[idx, j]].clone();
72                }
73            }
74        }
75        1 => {
76            // Take along columns
77            for i in 0..rows {
78                for (j, &idx) in indices.iter().enumerate() {
79                    result[[i, j]] = array[[i, idx]].clone();
80                }
81            }
82        }
83        _ => unreachable!(),
84    }
85
86    Ok(result)
87}
88
89/// Boolean mask indexing for 2D arrays
90///
91/// # Arguments
92///
93/// * `array` - The input 2D array
94/// * `mask` - Boolean mask of the same shape as the array
95///
96/// # Returns
97///
98/// A 1D array containing the elements where the mask is true
99///
100/// # Examples
101///
102/// ```
103/// use ndarray::array;
104/// use scirs2_core::ndarray_ext::indexing::boolean_mask_2d;
105///
106/// let a = array![[1, 2, 3], [4, 5, 6]];
107/// let mask = array![[true, false, true], [false, true, false]];
108/// let result = boolean_mask_2d(a.view(), mask.view()).unwrap();
109/// assert_eq!(result.len(), 3);
110/// assert_eq!(result[0], 1);
111/// assert_eq!(result[1], 3);
112/// assert_eq!(result[2], 5);
113/// ```
114#[allow(dead_code)]
115pub fn boolean_mask_2d<T>(
116    array: ArrayView<T, Ix2>,
117    mask: ArrayView<bool, Ix2>,
118) -> Result<Array<T, Ix1>, &'static str>
119where
120    T: Clone + Default,
121{
122    // Check that the mask has the same shape as the array
123    if array.shape() != mask.shape() {
124        return Err("Mask shape must match array shape");
125    }
126
127    // Count the number of true values in the mask
128    let true_count = mask.iter().filter(|&&x| x).count();
129
130    // Create the result array
131    let mut result = Array::<T, Ix1>::default(true_count);
132
133    // Fill the result array with elements where the mask is true
134    let mut idx = 0;
135    for (val, &m) in array.iter().zip(mask.iter()) {
136        if m {
137            result[idx] = val.clone();
138            idx += 1;
139        }
140    }
141
142    Ok(result)
143}
144
145/// Boolean mask indexing for 1D arrays
146///
147/// # Arguments
148///
149/// * `array` - The input 1D array
150/// * `mask` - Boolean mask of the same shape as the array
151///
152/// # Returns
153///
154/// A 1D array containing the elements where the mask is true
155///
156/// # Examples
157///
158/// ```
159/// use ndarray::array;
160/// use scirs2_core::ndarray_ext::indexing::boolean_mask_1d;
161///
162/// let a = array![1, 2, 3, 4, 5];
163/// let mask = array![true, false, true, false, true];
164/// let result = boolean_mask_1d(a.view(), mask.view()).unwrap();
165/// assert_eq!(result.len(), 3);
166/// assert_eq!(result[0], 1);
167/// assert_eq!(result[1], 3);
168/// assert_eq!(result[2], 5);
169/// ```
170#[allow(dead_code)]
171pub fn boolean_mask_1d<T>(
172    array: ArrayView<T, Ix1>,
173    mask: ArrayView<bool, Ix1>,
174) -> Result<Array<T, Ix1>, &'static str>
175where
176    T: Clone + Default,
177{
178    // Check that the mask has the same shape as the array
179    if array.shape() != mask.shape() {
180        return Err("Mask shape must match array shape");
181    }
182
183    // Count the number of true values in the mask
184    let true_count = mask.iter().filter(|&&x| x).count();
185
186    // Create the result array
187    let mut result = Array::<T, Ix1>::default(true_count);
188
189    // Fill the result array with elements where the mask is true
190    let mut idx = 0;
191    for (val, &m) in array.iter().zip(mask.iter()) {
192        if m {
193            result[idx] = val.clone();
194            idx += 1;
195        }
196    }
197
198    Ok(result)
199}
200
201/// Indexed slicing for 1D arrays
202///
203/// # Arguments
204///
205/// * `array` - The input 1D array
206/// * `indices` - Array of indices to extract
207///
208/// # Returns
209///
210/// A 1D array containing the elements at the specified indices
211///
212/// # Examples
213///
214/// ```
215/// use ndarray::array;
216/// use scirs2_core::ndarray_ext::indexing::take_1d;
217///
218/// let a = array![10, 20, 30, 40, 50];
219/// let indices = array![0, 2, 4];
220/// let result = take_1d(a.view(), indices.view()).unwrap();
221/// assert_eq!(result.len(), 3);
222/// assert_eq!(result[0], 10);
223/// assert_eq!(result[1], 30);
224/// assert_eq!(result[2], 50);
225/// ```
226#[allow(dead_code)]
227pub fn take_1d<T>(
228    array: ArrayView<T, Ix1>,
229    indices: ArrayView<usize, Ix1>,
230) -> Result<Array<T, Ix1>, &'static str>
231where
232    T: Clone + Default,
233{
234    let len = array.len();
235
236    // Verify that indices are in bounds
237    for &idx in indices.iter() {
238        if idx >= len {
239            return Err("Index out of bounds");
240        }
241    }
242
243    // Create the result array
244    let mut result = Array::<T, Ix1>::default(indices.len());
245
246    // Extract the elements at the specified indices
247    for (i, &idx) in indices.iter().enumerate() {
248        result[i] = array[idx].clone();
249    }
250
251    Ok(result)
252}
253
254/// Fancy indexing for 2D arrays with pairs of index arrays
255///
256/// # Arguments
257///
258/// * `array` - The input 2D array
259/// * `row_indices` - Array of row indices
260/// * `col_indices` - Array of column indices
261///
262/// # Returns
263///
264/// A 1D array containing the elements at the specified indices
265///
266/// # Examples
267///
268/// ```
269/// use ndarray::array;
270/// use scirs2_core::ndarray_ext::indexing::fancy_index_2d;
271///
272/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
273/// let row_indices = array![0, 1, 2];
274/// let col_indices = array![0, 1, 2];
275/// let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
276/// assert_eq!(result.len(), 3);
277/// assert_eq!(result[0], 1);
278/// assert_eq!(result[1], 5);
279/// assert_eq!(result[2], 9);
280/// ```
281#[allow(dead_code)]
282pub fn fancy_index_2d<T>(
283    array: ArrayView<T, Ix2>,
284    row_indices: ArrayView<usize, Ix1>,
285    col_indices: ArrayView<usize, Ix1>,
286) -> Result<Array<T, Ix1>, &'static str>
287where
288    T: Clone + Default,
289{
290    // Check that all index arrays have the same length
291    let result_size = row_indices.len();
292    if col_indices.len() != result_size {
293        return Err("Row and column index arrays must have the same length");
294    }
295
296    let (rows, cols) = (array.shape()[0], array.shape()[1]);
297
298    // Check that _indices are within bounds
299    for &idx in row_indices.iter() {
300        if idx >= rows {
301            return Err("Row index out of bounds");
302        }
303    }
304
305    for &idx in col_indices.iter() {
306        if idx >= cols {
307            return Err("Column index out of bounds");
308        }
309    }
310
311    // Create the result array
312    let mut result = Array::<T, Ix1>::default(result_size);
313
314    // Fill the result array
315    for i in 0..result_size {
316        let row = row_indices[i];
317        let col = col_indices[i];
318        result[i] = array[[row, col]].clone();
319    }
320
321    Ok(result)
322}
323
324/// Extract a diagonal or a sub-diagonal from a 2D array
325///
326/// # Arguments
327///
328/// * `array` - The input 2D array
329/// * `offset` - Offset from the main diagonal (0 for main diagonal, positive for above, negative for below)
330///
331/// # Returns
332///
333/// A 1D array containing the diagonal elements
334///
335/// # Examples
336///
337/// ```
338/// use ndarray::array;
339/// use scirs2_core::ndarray_ext::indexing::diagonal;
340///
341/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
342///
343/// // Main diagonal
344/// let main_diag = diagonal(a.view(), 0).unwrap();
345/// assert_eq!(main_diag.len(), 3);
346/// assert_eq!(main_diag[0], 1);
347/// assert_eq!(main_diag[1], 5);
348/// assert_eq!(main_diag[2], 9);
349///
350/// // Upper diagonal
351/// let upper_diag = diagonal(a.view(), 1).unwrap();
352/// assert_eq!(upper_diag.len(), 2);
353/// assert_eq!(upper_diag[0], 2);
354/// assert_eq!(upper_diag[1], 6);
355///
356/// // Lower diagonal
357/// let lower_diag = diagonal(a.view(), -1).unwrap();
358/// assert_eq!(lower_diag.len(), 2);
359/// assert_eq!(lower_diag[0], 4);
360/// assert_eq!(lower_diag[1], 8);
361/// ```
362#[allow(dead_code)]
363pub fn diagonal<T>(array: ArrayView<T, Ix2>, offset: isize) -> Result<Array<T, Ix1>, &'static str>
364where
365    T: Clone + Default,
366{
367    let (rows, cols) = (array.shape()[0], array.shape()[1]);
368
369    // Calculate the length of the diagonal
370    let diag_len = if offset >= 0 {
371        std::cmp::min(rows, cols.saturating_sub(offset as usize))
372    } else {
373        std::cmp::min(cols, rows.saturating_sub((-offset) as usize))
374    };
375
376    if diag_len == 0 {
377        return Err("No diagonal elements for the given offset");
378    }
379
380    // Create the result _array
381    let mut result = Array::<T, Ix1>::default(diag_len);
382
383    // Extract the diagonal elements
384    for i in 0..diag_len {
385        let row = if offset < 0 {
386            i + (-offset) as usize
387        } else {
388            i
389        };
390
391        let col = if offset > 0 { i + offset as usize } else { i };
392
393        result[i] = array[[row, col]].clone();
394    }
395
396    Ok(result)
397}
398
399/// Where function - select elements based on a condition for 1D arrays
400///
401/// # Arguments
402///
403/// * `array` - The input 1D array
404/// * `condition` - A function that takes a reference to an element and returns a bool
405///
406/// # Returns
407///
408/// A 1D array containing the elements where the condition is true
409///
410/// # Examples
411///
412/// ```
413/// use ndarray::array;
414/// use scirs2_core::ndarray_ext::indexing::where_1d;
415///
416/// let a = array![1, 2, 3, 4, 5];
417/// let result = where_1d(a.view(), |&x| x > 3).unwrap();
418/// assert_eq!(result.len(), 2);
419/// assert_eq!(result[0], 4);
420/// assert_eq!(result[1], 5);
421/// ```
422#[allow(dead_code)]
423pub fn where_1d<T, F>(array: ArrayView<T, Ix1>, condition: F) -> Result<Array<T, Ix1>, &'static str>
424where
425    T: Clone + Default,
426    F: Fn(&T) -> bool,
427{
428    // Build a boolean mask _array based on the condition
429    let mask = array.map(condition);
430
431    // Use the boolean_mask_1d function to select elements
432    boolean_mask_1d(array, mask.view())
433}
434
435/// Where function - select elements based on a condition for 2D arrays
436///
437/// # Arguments
438///
439/// * `array` - The input 2D array
440/// * `condition` - A function that takes a reference to an element and returns a bool
441///
442/// # Returns
443///
444/// A 1D array containing the elements where the condition is true
445///
446/// # Examples
447///
448/// ```
449/// use ndarray::array;
450/// use scirs2_core::ndarray_ext::indexing::where_2d;
451///
452/// let a = array![[1, 2, 3], [4, 5, 6]];
453/// let result = where_2d(a.view(), |&x| x > 3).unwrap();
454/// assert_eq!(result.len(), 3);
455/// assert_eq!(result[0], 4);
456/// assert_eq!(result[1], 5);
457/// assert_eq!(result[2], 6);
458/// ```
459#[allow(dead_code)]
460pub fn where_2d<T, F>(array: ArrayView<T, Ix2>, condition: F) -> Result<Array<T, Ix1>, &'static str>
461where
462    T: Clone + Default,
463    F: Fn(&T) -> bool,
464{
465    // Build a boolean mask _array based on the condition
466    let mask = array.map(condition);
467
468    // Use the boolean_mask_2d function to select elements
469    boolean_mask_2d(array, mask.view())
470}
471
472/// Extract indices where a 1D array meets a condition
473///
474/// # Arguments
475///
476/// * `array` - The input 1D array
477/// * `condition` - A function that takes a reference to an element and returns a bool
478///
479/// # Returns
480///
481/// A 1D array of indices where the condition is true
482///
483/// # Examples
484///
485/// ```
486/// use ndarray::array;
487/// use scirs2_core::ndarray_ext::indexing::indices_where_1d;
488///
489/// let a = array![10, 20, 30, 40, 50];
490/// let result = indices_where_1d(a.view(), |&x| x > 30).unwrap();
491/// assert_eq!(result.len(), 2);
492/// assert_eq!(result[0], 3);
493/// assert_eq!(result[1], 4);
494/// ```
495#[allow(dead_code)]
496pub fn indices_where_1d<T, F>(
497    array: ArrayView<T, Ix1>,
498    condition: F,
499) -> Result<Array<usize, Ix1>, &'static str>
500where
501    T: Clone,
502    F: Fn(&T) -> bool,
503{
504    // Build a vector of indices where the condition is true
505    let mut indices = Vec::new();
506
507    for (i, val) in array.iter().enumerate() {
508        if condition(val) {
509            indices.push(i);
510        }
511    }
512
513    // Convert the vector to an ndarray Array
514    Ok(Array::from_vec(indices))
515}
516
517/// Extract indices where a 2D array meets a condition
518///
519/// # Arguments
520///
521/// * `array` - The input 2D array
522/// * `condition` - A function that takes a reference to an element and returns a bool
523///
524/// # Returns
525///
526/// A tuple of two 1D arrays (row_indices, col_indices) where the condition is true
527///
528/// # Examples
529///
530/// ```
531/// use ndarray::array;
532/// use scirs2_core::ndarray_ext::indexing::indices_where_2d;
533///
534/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
535/// let (rows, cols) = indices_where_2d(a.view(), |&x| x > 5).unwrap();
536/// assert_eq!(rows.len(), 4);
537/// assert_eq!(cols.len(), 4);
538/// // The indices correspond to elements: 6, 7, 8, 9
539/// ```
540#[allow(dead_code)]
541pub fn indices_where_2d<T, F>(array: ArrayView<T, Ix2>, condition: F) -> IndicesResult
542where
543    T: Clone,
544    F: Fn(&T) -> bool,
545{
546    let (rows, cols) = (array.shape()[0], array.shape()[1]);
547
548    // Build vectors of row and column indices where the condition is true
549    let mut row_indices = Vec::new();
550    let mut col_indices = Vec::new();
551
552    for r in 0..rows {
553        for c in 0..cols {
554            if condition(&array[[r, c]]) {
555                row_indices.push(r);
556                col_indices.push(c);
557            }
558        }
559    }
560
561    // Convert the vectors to ndarray Arrays
562    Ok((Array::from_vec(row_indices), Array::from_vec(col_indices)))
563}
564
565/// Return elements from a 2D array along an axis at specified indices
566///
567/// # Arguments
568///
569/// * `array` - The input 2D array
570/// * `indices` - Indices to take along the specified axis
571/// * `axis` - The axis along which to take values (0 for rows, 1 for columns)
572///
573/// # Returns
574///
575/// A 2D array with selected slices from the original array
576///
577/// # Examples
578///
579/// ```
580/// use ndarray::array;
581/// use scirs2_core::ndarray_ext::indexing::take_along_axis;
582///
583/// let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
584/// let indices = array![0, 2];
585///
586/// // Take rows 0 and 2
587/// let result = take_along_axis(a.view(), indices.view(), 0).unwrap();
588/// assert_eq!(result.shape(), &[2, 3]);
589/// assert_eq!(result[[0, 0]], 1);
590/// assert_eq!(result[[0, 1]], 2);
591/// assert_eq!(result[[0, 2]], 3);
592/// assert_eq!(result[[1, 0]], 7);
593/// assert_eq!(result[[1, 1]], 8);
594/// assert_eq!(result[[1, 2]], 9);
595/// ```
596#[allow(dead_code)]
597pub fn take_along_axis<T>(
598    array: ArrayView<T, Ix2>,
599    indices: ArrayView<usize, Ix1>,
600    axis: usize,
601) -> Result<Array<T, Ix2>, &'static str>
602where
603    T: Clone + Default,
604{
605    take_2d(array, indices, axis)
606}
607
#[cfg(test)]
mod tests {
    //! Unit tests for the indexing helpers: the documented happy paths plus
    //! the error and empty-result cases of each function.

    use super::*;
    use ndarray::array;

    #[test]
    fn test_boolean_mask_1d() {
        let a = array![1, 2, 3, 4, 5];
        let mask = array![true, false, true, false, true];

        let result = boolean_mask_1d(a.view(), mask.view()).unwrap();
        assert_eq!(result, array![1, 3, 5]);
    }

    #[test]
    fn test_boolean_mask_1d_shape_mismatch() {
        let a = array![1, 2, 3];
        let mask = array![true, false];

        let err = boolean_mask_1d(a.view(), mask.view()).unwrap_err();
        assert_eq!(err, "Mask shape must match array shape");
    }

    #[test]
    fn test_boolean_mask_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];

        let result = boolean_mask_2d(a.view(), mask.view()).unwrap();
        assert_eq!(result, array![1, 3, 5]);
    }

    #[test]
    fn test_boolean_mask_2d_shape_mismatch() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false], [false, true]];

        let err = boolean_mask_2d(a.view(), mask.view()).unwrap_err();
        assert_eq!(err, "Mask shape must match array shape");
    }

    #[test]
    fn test_take_1d() {
        let a = array![10, 20, 30, 40, 50];
        let indices = array![0, 2, 4];

        let result = take_1d(a.view(), indices.view()).unwrap();
        assert_eq!(result, array![10, 30, 50]);
    }

    #[test]
    fn test_take_1d_out_of_bounds() {
        let a = array![10, 20, 30];
        let indices = array![0, 3];

        let err = take_1d(a.view(), indices.view()).unwrap_err();
        assert_eq!(err, "Index out of bounds");
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let indices = array![0, 2];

        // Take along axis 0 (rows)
        let result = take_2d(a.view(), indices.view(), 0).unwrap();
        assert_eq!(result, array![[1, 2, 3], [7, 8, 9]]);

        // Take along axis 1 (columns)
        let result = take_2d(a.view(), indices.view(), 1).unwrap();
        assert_eq!(result, array![[1, 3], [4, 6], [7, 9]]);
    }

    #[test]
    fn test_take_2d_errors() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        // In-range indices combined with an axis that does not exist
        let indices = array![0, 2];
        let err = take_2d(a.view(), indices.view(), 2).unwrap_err();
        assert_eq!(err, "Axis must be 0 or 1 for 2D arrays");

        // Valid axis, index one past the end
        let too_far = array![3];
        let err = take_2d(a.view(), too_far.view(), 0).unwrap_err();
        assert_eq!(err, "Index out of bounds");
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];

        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view()).unwrap();
        assert_eq!(result, array![1, 8]);
    }

    #[test]
    fn test_fancy_index_2d_errors() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        // Index arrays of different lengths
        let rows = array![0, 1];
        let cols = array![0];
        let err = fancy_index_2d(a.view(), rows.view(), cols.view()).unwrap_err();
        assert_eq!(err, "Row and column index arrays must have the same length");

        // Row position past the last row
        let rows = array![3];
        let cols = array![0];
        let err = fancy_index_2d(a.view(), rows.view(), cols.view()).unwrap_err();
        assert_eq!(err, "Row index out of bounds");

        // Column position past the last column
        let rows = array![0];
        let cols = array![3];
        let err = fancy_index_2d(a.view(), rows.view(), cols.view()).unwrap_err();
        assert_eq!(err, "Column index out of bounds");
    }

    #[test]
    fn test_diagonal() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        // Main diagonal
        let main_diag = diagonal(a.view(), 0).unwrap();
        assert_eq!(main_diag, array![1, 5, 9]);

        // Upper diagonal
        let upper_diag = diagonal(a.view(), 1).unwrap();
        assert_eq!(upper_diag, array![2, 6]);

        // Lower diagonal
        let lower_diag = diagonal(a.view(), -1).unwrap();
        assert_eq!(lower_diag, array![4, 8]);
    }

    #[test]
    fn test_diagonal_non_square() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        assert_eq!(diagonal(a.view(), 0).unwrap(), array![1, 5]);
        assert_eq!(diagonal(a.view(), 1).unwrap(), array![2, 6]);
        assert_eq!(diagonal(a.view(), 2).unwrap(), array![3]);
        assert_eq!(diagonal(a.view(), -1).unwrap(), array![4]);
    }

    #[test]
    fn test_diagonal_offset_outside_array() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        let err = diagonal(a.view(), 3).unwrap_err();
        assert_eq!(err, "No diagonal elements for the given offset");

        let err = diagonal(a.view(), -3).unwrap_err();
        assert_eq!(err, "No diagonal elements for the given offset");
    }

    #[test]
    fn test_where_1d() {
        let a = array![1, 2, 3, 4, 5];

        let result = where_1d(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result, array![4, 5]);

        // No match yields an empty array rather than an error
        let none = where_1d(a.view(), |&x| x > 100).unwrap();
        assert_eq!(none.len(), 0);
    }

    #[test]
    fn test_where_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        let result = where_2d(a.view(), |&x| x > 3).unwrap();
        assert_eq!(result, array![4, 5, 6]);

        // No match yields an empty array rather than an error
        let none = where_2d(a.view(), |&x| x > 100).unwrap();
        assert_eq!(none.len(), 0);
    }

    #[test]
    fn test_indices_where_1d() {
        let a = array![10, 20, 30, 40, 50];

        let result = indices_where_1d(a.view(), |&x| x > 30).unwrap();
        assert_eq!(result, array![3, 4]);
    }

    #[test]
    fn test_indices_where_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        let (rows, cols) = indices_where_2d(a.view(), |&x| x > 5).unwrap();

        // Matches are 6, 7, 8, 9, listed row by row
        assert_eq!(rows, array![1, 2, 2, 2]);
        assert_eq!(cols, array![2, 0, 1, 2]);

        // Verify that the indices correspond to the expected elements
        for (r, c) in rows.iter().zip(cols.iter()) {
            assert!(a[[*r, *c]] > 5);
        }
    }

    #[test]
    fn test_take_along_axis() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let indices = array![0, 2];

        // Test along axis 0 (rows)
        let result = take_along_axis(a.view(), indices.view(), 0).unwrap();
        assert_eq!(result, array![[1, 2, 3], [7, 8, 9]]);

        // Test along axis 1 (columns)
        let result = take_along_axis(a.view(), indices.view(), 1).unwrap();
        assert_eq!(result, array![[1, 3], [4, 6], [7, 9]]);
    }
}