//! train_station/tensor/indexing/index_select.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Select elements along a dimension using a list of indices
6    ///
7    /// This operation extracts elements from the input tensor along a specified dimension
8    /// using the provided indices. The output tensor has the same shape as the input
9    /// except along the specified dimension, where the size becomes the length of the
10    /// indices array.
11    ///
12    /// The index_select operation is commonly used for extracting specific rows, columns,
13    /// or slices from tensors, and is particularly useful in machine learning for
14    /// operations like embedding lookups and attention mechanisms.
15    ///
16    /// # Arguments
17    ///
18    /// * `dim` - The dimension along which to select elements (must be < tensor rank)
19    /// * `indices` - Array of indices specifying which elements to select along `dim`
20    ///
21    /// # Returns
22    ///
23    /// A new tensor with the same shape as the input except along `dim`, where the
24    /// size is `indices.len()`
25    ///
26    /// # Examples
27    ///
28    /// ## Basic Index Selection
29    ///
30    /// ```
31    /// use train_station::Tensor;
32    ///
33    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
34    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
35    ///
36    /// // Select columns 2 and 0 from dimension 1
37    /// let result = tensor.index_select(1, &[2, 0]);
38    ///
39    /// // Result shape is [2, 2] (same as input except dim 1 is now 2)
40    /// assert_eq!(result.shape().dims, vec![2, 2]);
41    ///
42    /// // Row 0: selected columns [2, 0] -> [2.0, 0.0]
43    /// assert_eq!(result.get(&[0, 0]), 2.0);
44    /// assert_eq!(result.get(&[0, 1]), 0.0);
45    ///
46    /// // Row 1: selected columns [2, 0] -> [5.0, 3.0]
47    /// assert_eq!(result.get(&[1, 0]), 5.0);
48    /// assert_eq!(result.get(&[1, 1]), 3.0);
49    /// ```
50    ///
51    /// ## Index Selection with Gradient Tracking
52    ///
53    /// ```
54    /// use train_station::Tensor;
55    ///
56    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap()
57    ///     .with_requires_grad();
58    ///
59    /// // Select specific elements with gradient tracking enabled
60    /// let mut result = tensor.index_select(1, &[1, 2]);
61    /// result.backward(None);
62    ///
63    /// // Verify gradients are computed correctly
64    /// let grad = tensor.grad_by_value().expect("gradient missing");
65    /// assert_eq!(grad.shape().dims, vec![2, 3]);
66    /// ```
67    ///
68    /// ## Selecting Rows from a Matrix
69    ///
70    /// ```
71    /// use train_station::Tensor;
72    ///
73    /// // Create a 3x2 matrix
74    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
75    ///
76    /// // Select rows 2 and 0 (dimension 0)
77    /// let result = tensor.index_select(0, &[2, 0]);
78    ///
79    /// // Result shape is [2, 2]
80    /// assert_eq!(result.shape().dims, vec![2, 2]);
81    ///
82    /// // Selected rows: row 2 [5.0, 6.0], row 0 [1.0, 2.0]
83    /// assert_eq!(result.get(&[0, 0]), 5.0); // First row of result (was row 2)
84    /// assert_eq!(result.get(&[0, 1]), 6.0);
85    /// assert_eq!(result.get(&[1, 0]), 1.0); // Second row of result (was row 0)
86    /// assert_eq!(result.get(&[1, 1]), 2.0);
87    /// ```
88    ///
89    /// # Performance Characteristics
90    ///
91    /// - **Time Complexity**: O(n) where n is the number of elements in the output tensor
92    /// - **Memory Usage**: Creates a new tensor with size equal to the output shape
93    /// - **Optimization**: Uses precomputed strides for efficient memory access
94    /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
95    /// - **Memory Layout**: Output tensor is always contiguous for optimal performance
96    ///
97    /// # Implementation Details
98    ///
99    /// The index_select operation works by:
100    /// 1. Validating the dimension and index bounds
101    /// 2. Computing the output shape (same as input except along `dim`)
102    /// 3. Creating a new contiguous output tensor
103    /// 4. Iterating through all positions in the output tensor using nested loops:
104    ///    - Outer loop: iterate over dimensions before `dim`
105    ///    - Middle loop: iterate over the selected indices
106    ///    - Inner loop: iterate over dimensions after `dim`
107    /// 5. Computing source offsets using the input tensor's strides
108    /// 6. Copying values from input to output tensor
109    /// 7. Registering the operation for gradient computation if needed
110    ///
111    /// # Safety
112    ///
113    /// This function performs comprehensive bounds checking to ensure:
114    /// - The specified dimension is within the tensor's rank
115    /// - All indices are within bounds for the specified dimension
116    /// - Memory access is safe through proper offset calculations
117    ///
118    /// # Panics
119    ///
120    /// This function will panic if:
121    /// - `dim` is greater than or equal to the tensor's rank
122    /// - Any index in `indices` is out of bounds for the specified dimension
123    ///
124    /// # Thread Safety
125    ///
126    /// This function is thread-safe and can be called concurrently on different tensors.
127    /// The operation does not modify the input tensor and creates a new output tensor.
128    pub fn index_select(&self, dim: usize, indices: &[usize]) -> Tensor {
129        let rank = self.shape().rank();
130        assert!(
131            dim < rank,
132            "index_select dim {} out of bounds for rank {}",
133            dim,
134            rank
135        );
136        for &idx in indices {
137            assert!(
138                idx < self.shape().dims[dim],
139                "index {} out of bounds for dimension {} (size {})",
140                idx,
141                dim,
142                self.shape().dims[dim]
143            );
144        }
145
146        // Output shape is same as input except along dim -> indices.len()
147        let mut out_dims = self.shape().dims.clone();
148        out_dims[dim] = indices.len();
149        let mut output = Tensor::new(out_dims.clone());
150
151        // Precompute strides for fast offset computation
152        let in_strides = self.strides().to_vec();
153        let out_inner: usize = out_dims[dim + 1..].iter().product();
154        let out_outer: usize = out_dims[..dim].iter().product();
155
156        unsafe {
157            let dst_ptr = output.as_mut_ptr();
158            for outer_idx in 0..out_outer {
159                // Decode outer_idx into coordinates for axes < dim
160                let mut coords = vec![0usize; rank];
161                if dim > 0 {
162                    let mut tmp = outer_idx;
163                    for i in (0..dim).rev() {
164                        let s = self.shape().dims[i];
165                        coords[i] = tmp % s;
166                        tmp /= s;
167                    }
168                }
169
170                for (j, &sel) in indices.iter().enumerate() {
171                    coords[dim] = sel;
172                    // Iterate over inner block
173                    for inner_idx in 0..out_inner {
174                        // Decode inner_idx into coordinates for axes > dim
175                        let mut tmp = inner_idx;
176                        for (i, c) in coords.iter_mut().enumerate().take(rank).skip(dim + 1) {
177                            let s = self.shape().dims[i];
178                            *c = tmp % s;
179                            tmp /= s;
180                        }
181
182                        // Compute input offset via strides
183                        let mut src_off = 0usize;
184                        for i in 0..rank {
185                            src_off += coords[i] * in_strides[i];
186                        }
187
188                        // Destination offset within output tensor (contiguous)
189                        let out_block =
190                            outer_idx * (indices.len() * out_inner) + j * out_inner + inner_idx;
191                        *dst_ptr.add(out_block) = *self.as_ptr().add(src_off);
192                    }
193                }
194            }
195        }
196
197        // GradTrack registration
198        if self.requires_grad() {
199            output.set_requires_grad(true);
200            let grad_fn = GradFn::IndexSelect {
201                dim,
202                indices: indices.to_vec(),
203                input_shape: self.shape().dims.clone(),
204            };
205            output.set_grad_fn(grad_fn.clone());
206            GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
207        }
208
209        output
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting columns [2, 0] from a 2x3 tensor yields a 2x2 tensor with the
    /// chosen columns in the requested order.
    #[test]
    fn test_index_select_basic() {
        let input =
            Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let selected = input.index_select(1, &[2, 0]);

        assert_eq!(selected.shape().dims, vec![2, 2]);

        // (row, col) -> expected value from the original tensor.
        let expected = [
            ([0, 0], 2.0),
            ([0, 1], 0.0),
            ([1, 0], 5.0),
            ([1, 1], 3.0),
        ];
        for (pos, want) in expected {
            assert_eq!(selected.get(&pos), want);
        }
    }
}
228}