//! train_station/tensor/indexing/index_select.rs
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Select elements along a dimension using a list of indices
6 ///
7 /// This operation extracts elements from the input tensor along a specified dimension
8 /// using the provided indices. The output tensor has the same shape as the input
9 /// except along the specified dimension, where the size becomes the length of the
10 /// indices array.
11 ///
12 /// The index_select operation is commonly used for extracting specific rows, columns,
13 /// or slices from tensors, and is particularly useful in machine learning for
14 /// operations like embedding lookups and attention mechanisms.
15 ///
16 /// # Arguments
17 ///
18 /// * `dim` - The dimension along which to select elements (must be < tensor rank)
19 /// * `indices` - Array of indices specifying which elements to select along `dim`
20 ///
21 /// # Returns
22 ///
23 /// A new tensor with the same shape as the input except along `dim`, where the
24 /// size is `indices.len()`
25 ///
26 /// # Examples
27 ///
28 /// ## Basic Index Selection
29 ///
30 /// ```
31 /// use train_station::Tensor;
32 ///
33 /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
34 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
35 ///
36 /// // Select columns 2 and 0 from dimension 1
37 /// let result = tensor.index_select(1, &[2, 0]);
38 ///
39 /// // Result shape is [2, 2] (same as input except dim 1 is now 2)
40 /// assert_eq!(result.shape().dims(), vec![2, 2]);
41 ///
42 /// // Row 0: selected columns [2, 0] -> [2.0, 0.0]
43 /// assert_eq!(result.get(&[0, 0]), 2.0);
44 /// assert_eq!(result.get(&[0, 1]), 0.0);
45 ///
46 /// // Row 1: selected columns [2, 0] -> [5.0, 3.0]
47 /// assert_eq!(result.get(&[1, 0]), 5.0);
48 /// assert_eq!(result.get(&[1, 1]), 3.0);
49 /// ```
50 ///
51 /// ## Index Selection with Gradient Tracking
52 ///
53 /// ```
54 /// use train_station::Tensor;
55 ///
56 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap()
57 /// .with_requires_grad();
58 ///
59 /// // Select specific elements with gradient tracking enabled
60 /// let mut result = tensor.index_select(1, &[1, 2]);
61 /// result.backward(None);
62 ///
63 /// // Verify gradients are computed correctly
64 /// let grad = tensor.grad_owned().expect("gradient missing");
65 /// assert_eq!(grad.shape().dims(), vec![2, 3]);
66 /// ```
67 ///
68 /// ## Selecting Rows from a Matrix
69 ///
70 /// ```
71 /// use train_station::Tensor;
72 ///
73 /// // Create a 3x2 matrix
74 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
75 ///
76 /// // Select rows 2 and 0 (dimension 0)
77 /// let result = tensor.index_select(0, &[2, 0]);
78 ///
79 /// // Result shape is [2, 2]
80 /// assert_eq!(result.shape().dims(), vec![2, 2]);
81 ///
82 /// // Selected rows: row 2 [5.0, 6.0], row 0 [1.0, 2.0]
83 /// assert_eq!(result.get(&[0, 0]), 5.0); // First row of result (was row 2)
84 /// assert_eq!(result.get(&[0, 1]), 6.0);
85 /// assert_eq!(result.get(&[1, 0]), 1.0); // Second row of result (was row 0)
86 /// assert_eq!(result.get(&[1, 1]), 2.0);
87 /// ```
88 ///
89 /// # Performance Characteristics
90 ///
91 /// - **Time Complexity**: O(n) where n is the number of elements in the output tensor
92 /// - **Memory Usage**: Creates a new tensor with size equal to the output shape
93 /// - **Optimization**: Uses precomputed strides for efficient memory access
94 /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
95 /// - **Memory Layout**: Output tensor is always contiguous for optimal performance
96 ///
97 /// # Implementation Details
98 ///
99 /// The index_select operation works by:
100 /// 1. Validating the dimension and index bounds
101 /// 2. Computing the output shape (same as input except along `dim`)
102 /// 3. Creating a new contiguous output tensor
103 /// 4. Iterating through all positions in the output tensor using nested loops:
104 /// - Outer loop: iterate over dimensions before `dim`
105 /// - Middle loop: iterate over the selected indices
106 /// - Inner loop: iterate over dimensions after `dim`
107 /// 5. Computing source offsets using the input tensor's strides
108 /// 6. Copying values from input to output tensor
109 /// 7. Registering the operation for gradient computation if needed
110 ///
111 /// # Safety
112 ///
113 /// This function performs comprehensive bounds checking to ensure:
114 /// - The specified dimension is within the tensor's rank
115 /// - All indices are within bounds for the specified dimension
116 /// - Memory access is safe through proper offset calculations
117 ///
118 /// # Panics
119 ///
120 /// This function will panic if:
121 /// - `dim` is greater than or equal to the tensor's rank
122 /// - Any index in `indices` is out of bounds for the specified dimension
123 ///
124 /// # Thread Safety
125 ///
126 /// This function is thread-safe and can be called concurrently on different tensors.
127 /// The operation does not modify the input tensor and creates a new output tensor.
128 #[track_caller]
129 pub fn index_select(&self, dim: usize, indices: &[usize]) -> Tensor {
130 let rank = self.shape().rank();
131 assert!(
132 dim < rank,
133 "index_select dim {} out of bounds for rank {}",
134 dim,
135 rank
136 );
137 for &idx in indices {
138 assert!(
139 idx < self.shape().dims()[dim],
140 "index {} out of bounds for dimension {} (size {})",
141 idx,
142 dim,
143 self.shape().dims()[dim]
144 );
145 }
146
147 // Output shape is same as input except along dim -> indices.len()
148 let mut out_dims = self.shape().dims().to_vec();
149 out_dims[dim] = indices.len();
150 let mut output = Tensor::new(out_dims.clone());
151
152 // Precompute strides for fast offset computation
153 let in_strides = self.strides().to_vec();
154 let out_inner: usize = out_dims[dim + 1..].iter().product();
155 let out_outer: usize = out_dims[..dim].iter().product();
156
157 unsafe {
158 let dst_ptr = output.as_mut_ptr();
159 for outer_idx in 0..out_outer {
160 // Decode outer_idx into coordinates for axes < dim
161 let mut coords = vec![0usize; rank];
162 if dim > 0 {
163 let mut tmp = outer_idx;
164 for i in (0..dim).rev() {
165 let s = self.shape().dims()[i];
166 coords[i] = tmp % s;
167 tmp /= s;
168 }
169 }
170
171 for (j, &sel) in indices.iter().enumerate() {
172 coords[dim] = sel;
173 // Iterate over inner block
174 for inner_idx in 0..out_inner {
175 // Decode inner_idx into coordinates for axes > dim
176 let mut tmp = inner_idx;
177 for (i, c) in coords.iter_mut().enumerate().take(rank).skip(dim + 1) {
178 let s = self.shape().dims()[i];
179 *c = tmp % s;
180 tmp /= s;
181 }
182
183 // Compute input offset via strides
184 let mut src_off = 0usize;
185 for i in 0..rank {
186 src_off += coords[i] * in_strides[i];
187 }
188
189 // Destination offset within output tensor (contiguous)
190 let out_block =
191 outer_idx * (indices.len() * out_inner) + j * out_inner + inner_idx;
192 *dst_ptr.add(out_block) = *self.as_ptr().add(src_off);
193 }
194 }
195 }
196 }
197
198 // GradTrack registration
199 if self.requires_grad() {
200 output.set_requires_grad(true);
201 let grad_fn = GradFn::IndexSelect {
202 dim,
203 indices: indices.to_vec(),
204 input_shape: self.shape().dims().to_vec(),
205 };
206 output.set_grad_fn(grad_fn.clone());
207 GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
208 }
209
210 output
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting columns [2, 0] from a 2x3 tensor must yield a 2x2 tensor
    /// whose rows are the picked columns in the requested order.
    #[test]
    fn test_index_select_basic() {
        let data: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let input = Tensor::from_slice(&data, vec![2, 3]).unwrap();

        let selected = input.index_select(1, &[2, 0]);
        assert_eq!(selected.shape().dims(), vec![2, 2]);

        // (coordinate, expected value) pairs covering every output element.
        let expected = [
            ([0usize, 0usize], 2.0f32),
            ([0, 1], 0.0),
            ([1, 0], 5.0),
            ([1, 1], 3.0),
        ];
        for (coords, value) in expected {
            assert_eq!(selected.get(&coords), value);
        }
    }
}
229}