train_station/tensor/indexing/
index_select.rs

1use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Select elements along a dimension using a list of indices
6    ///
7    /// This operation extracts elements from the input tensor along a specified dimension
8    /// using the provided indices. The output tensor has the same shape as the input
9    /// except along the specified dimension, where the size becomes the length of the
10    /// indices array.
11    ///
12    /// The index_select operation is commonly used for extracting specific rows, columns,
13    /// or slices from tensors, and is particularly useful in machine learning for
14    /// operations like embedding lookups and attention mechanisms.
15    ///
16    /// # Arguments
17    ///
18    /// * `dim` - The dimension along which to select elements (must be < tensor rank)
19    /// * `indices` - Array of indices specifying which elements to select along `dim`
20    ///
21    /// # Returns
22    ///
23    /// A new tensor with the same shape as the input except along `dim`, where the
24    /// size is `indices.len()`
25    ///
26    /// # Examples
27    ///
28    /// ## Basic Index Selection
29    ///
30    /// ```
31    /// use train_station::Tensor;
32    ///
33    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
34    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
35    ///
36    /// // Select columns 2 and 0 from dimension 1
37    /// let result = tensor.index_select(1, &[2, 0]);
38    ///
39    /// // Result shape is [2, 2] (same as input except dim 1 is now 2)
40    /// assert_eq!(result.shape().dims(), vec![2, 2]);
41    ///
42    /// // Row 0: selected columns [2, 0] -> [2.0, 0.0]
43    /// assert_eq!(result.get(&[0, 0]), 2.0);
44    /// assert_eq!(result.get(&[0, 1]), 0.0);
45    ///
46    /// // Row 1: selected columns [2, 0] -> [5.0, 3.0]
47    /// assert_eq!(result.get(&[1, 0]), 5.0);
48    /// assert_eq!(result.get(&[1, 1]), 3.0);
49    /// ```
50    ///
51    /// ## Index Selection with Gradient Tracking
52    ///
53    /// ```
54    /// use train_station::Tensor;
55    ///
56    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap()
57    ///     .with_requires_grad();
58    ///
59    /// // Select specific elements with gradient tracking enabled
60    /// let mut result = tensor.index_select(1, &[1, 2]);
61    /// result.backward(None);
62    ///
63    /// // Verify gradients are computed correctly
64    /// let grad = tensor.grad_owned().expect("gradient missing");
65    /// assert_eq!(grad.shape().dims(), vec![2, 3]);
66    /// ```
67    ///
68    /// ## Selecting Rows from a Matrix
69    ///
70    /// ```
71    /// use train_station::Tensor;
72    ///
73    /// // Create a 3x2 matrix
74    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
75    ///
76    /// // Select rows 2 and 0 (dimension 0)
77    /// let result = tensor.index_select(0, &[2, 0]);
78    ///
79    /// // Result shape is [2, 2]
80    /// assert_eq!(result.shape().dims(), vec![2, 2]);
81    ///
82    /// // Selected rows: row 2 [5.0, 6.0], row 0 [1.0, 2.0]
83    /// assert_eq!(result.get(&[0, 0]), 5.0); // First row of result (was row 2)
84    /// assert_eq!(result.get(&[0, 1]), 6.0);
85    /// assert_eq!(result.get(&[1, 0]), 1.0); // Second row of result (was row 0)
86    /// assert_eq!(result.get(&[1, 1]), 2.0);
87    /// ```
88    ///
89    /// # Performance Characteristics
90    ///
91    /// - **Time Complexity**: O(n) where n is the number of elements in the output tensor
92    /// - **Memory Usage**: Creates a new tensor with size equal to the output shape
93    /// - **Optimization**: Uses precomputed strides for efficient memory access
94    /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
95    /// - **Memory Layout**: Output tensor is always contiguous for optimal performance
96    ///
97    /// # Implementation Details
98    ///
99    /// The index_select operation works by:
100    /// 1. Validating the dimension and index bounds
101    /// 2. Computing the output shape (same as input except along `dim`)
102    /// 3. Creating a new contiguous output tensor
103    /// 4. Iterating through all positions in the output tensor using nested loops:
104    ///    - Outer loop: iterate over dimensions before `dim`
105    ///    - Middle loop: iterate over the selected indices
106    ///    - Inner loop: iterate over dimensions after `dim`
107    /// 5. Computing source offsets using the input tensor's strides
108    /// 6. Copying values from input to output tensor
109    /// 7. Registering the operation for gradient computation if needed
110    ///
111    /// # Safety
112    ///
113    /// This function performs comprehensive bounds checking to ensure:
114    /// - The specified dimension is within the tensor's rank
115    /// - All indices are within bounds for the specified dimension
116    /// - Memory access is safe through proper offset calculations
117    ///
118    /// # Panics
119    ///
120    /// This function will panic if:
121    /// - `dim` is greater than or equal to the tensor's rank
122    /// - Any index in `indices` is out of bounds for the specified dimension
123    ///
124    /// # Thread Safety
125    ///
126    /// This function is thread-safe and can be called concurrently on different tensors.
127    /// The operation does not modify the input tensor and creates a new output tensor.
128    #[track_caller]
129    pub fn index_select(&self, dim: usize, indices: &[usize]) -> Tensor {
130        let rank = self.shape().rank();
131        assert!(
132            dim < rank,
133            "index_select dim {} out of bounds for rank {}",
134            dim,
135            rank
136        );
137        for &idx in indices {
138            assert!(
139                idx < self.shape().dims()[dim],
140                "index {} out of bounds for dimension {} (size {})",
141                idx,
142                dim,
143                self.shape().dims()[dim]
144            );
145        }
146
147        // Output shape is same as input except along dim -> indices.len()
148        let mut out_dims = self.shape().dims().to_vec();
149        out_dims[dim] = indices.len();
150        let mut output = Tensor::new(out_dims.clone());
151
152        // Precompute strides for fast offset computation
153        let in_strides = self.strides().to_vec();
154        let out_inner: usize = out_dims[dim + 1..].iter().product();
155        let out_outer: usize = out_dims[..dim].iter().product();
156
157        unsafe {
158            let dst_ptr = output.as_mut_ptr();
159            for outer_idx in 0..out_outer {
160                // Decode outer_idx into coordinates for axes < dim
161                let mut coords = vec![0usize; rank];
162                if dim > 0 {
163                    let mut tmp = outer_idx;
164                    for i in (0..dim).rev() {
165                        let s = self.shape().dims()[i];
166                        coords[i] = tmp % s;
167                        tmp /= s;
168                    }
169                }
170
171                for (j, &sel) in indices.iter().enumerate() {
172                    coords[dim] = sel;
173                    // Iterate over inner block
174                    for inner_idx in 0..out_inner {
175                        // Decode inner_idx into coordinates for axes > dim
176                        let mut tmp = inner_idx;
177                        for (i, c) in coords.iter_mut().enumerate().take(rank).skip(dim + 1) {
178                            let s = self.shape().dims()[i];
179                            *c = tmp % s;
180                            tmp /= s;
181                        }
182
183                        // Compute input offset via strides
184                        let mut src_off = 0usize;
185                        for i in 0..rank {
186                            src_off += coords[i] * in_strides[i];
187                        }
188
189                        // Destination offset within output tensor (contiguous)
190                        let out_block =
191                            outer_idx * (indices.len() * out_inner) + j * out_inner + inner_idx;
192                        *dst_ptr.add(out_block) = *self.as_ptr().add(src_off);
193                    }
194                }
195            }
196        }
197
198        // GradTrack registration
199        if self.requires_grad() && is_grad_enabled() {
200            output.set_requires_grad(true);
201            let grad_fn = GradFn::IndexSelect {
202                dim,
203                indices: indices.to_vec(),
204                input_shape: self.shape().dims().to_vec(),
205            };
206            output.set_grad_fn(grad_fn.clone());
207            GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
208        }
209
210        output
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Picking columns `[2, 0]` from a 2x3 ramp tensor reorders every row the same way.
    #[test]
    fn test_index_select_basic() {
        let ramp: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let source = Tensor::from_slice(&ramp, vec![2, 3]).unwrap();

        let picked = source.index_select(1, &[2, 0]);
        assert_eq!(picked.shape().dims(), vec![2, 2]);

        // (row, column, expected value) for each position of the 2x2 result
        let cases: [(usize, usize, _); 4] = [(0, 0, 2.0), (0, 1, 0.0), (1, 0, 5.0), (1, 1, 3.0)];
        for &(row, col, want) in cases.iter() {
            assert_eq!(picked.get(&[row, col]), want);
        }
    }
}