train_station/tensor/indexing/gather.rs
1use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Gather values along a dimension using a tensor of indices
6 ///
7 /// This operation extracts elements from the input tensor based on indices provided
8 /// along a specified dimension. The output tensor has the same shape as the index
9 /// tensor, with each element taken from the input tensor at the corresponding
10 /// position with the index value substituted for the specified dimension.
11 ///
12 /// The gather operation is commonly used in machine learning for operations like
13 /// embedding lookups, attention mechanisms, and advanced indexing patterns.
14 ///
15 /// # Arguments
16 ///
17 /// * `dim` - The dimension along which to gather values (must be < tensor rank)
18 /// * `indices` - Flattened indices buffer containing the positions to gather from
19 /// * `index_shape` - Shape of the indices tensor and output tensor
20 ///
21 /// # Returns
22 ///
23 /// A new tensor with shape `index_shape` containing the gathered values
24 ///
25 /// # Examples
26 ///
27 /// ## Basic Gather Operation
28 ///
29 /// ```
30 /// use train_station::Tensor;
31 ///
32 /// // Create a 2x3 tensor: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
33 /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
34 ///
35 /// // Gather along dimension 1 (columns) with indices [2, 0, 1, 1]
36 /// let indices = [2, 0, 1, 1];
37 /// let index_shape = [2, 2];
38 /// let result = tensor.gather(1, &indices, &index_shape);
39 ///
40 /// // Result shape is [2, 2]
41 /// assert_eq!(result.shape().dims(), vec![2, 2]);
42 ///
43 /// // Row 0: indices [2, 0] -> [0.2, 0.0]
44 /// assert!((result.get(&[0, 0]) - 0.2).abs() < 1e-6);
45 /// assert!((result.get(&[0, 1]) - 0.0).abs() < 1e-6);
46 ///
47 /// // Row 1: indices [1, 1] -> [0.4, 0.4]
48 /// assert!((result.get(&[1, 0]) - 0.4).abs() < 1e-6);
49 /// assert!((result.get(&[1, 1]) - 0.4).abs() < 1e-6);
50 /// ```
51 ///
52 /// ## Gather with Gradient Tracking
53 ///
54 /// ```
55 /// use train_station::Tensor;
56 ///
57 /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
58 /// .with_requires_grad();
59 ///
60 /// let indices = [1, 1, 0, 2];
61 /// let index_shape = [2, 2];
62 /// let mut result = tensor.gather(1, &indices, &index_shape);
63 ///
64 /// // Compute gradients
65 /// result.backward(None);
66 /// let grad = tensor.grad_owned().expect("gradient missing");
67 ///
68 /// // Verify gradient accumulation for repeated indices
69 /// assert!((grad.get(&[0, 1]) - 2.0).abs() < 1e-6); // Index 1 used twice in row 0
70 /// ```
71 ///
72 /// # Performance Characteristics
73 ///
74 /// - **Time Complexity**: O(n) where n is the number of elements in the output
75 /// - **Memory Usage**: Creates a new tensor with the same size as the index tensor
76 /// - **Optimization**: Uses precomputed strides for efficient memory access
77 /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
78 ///
79 /// # Implementation Details
80 ///
81 /// The gather operation works by:
82 /// 1. Validating input dimensions and index bounds
83 /// 2. Creating an output tensor with the specified index shape
84 /// 3. Iterating through all positions in the output tensor
85 /// 4. Computing source offsets using the input tensor's strides
86 /// 5. Copying values from the input tensor to the output tensor
87 /// 6. Registering the operation for gradient computation if needed
88 ///
89 /// # Safety
90 ///
91 /// This function performs bounds checking to ensure:
92 /// - The specified dimension is within the tensor's rank
93 /// - All indices are within bounds for the specified dimension
94 /// - The index shape is compatible with the input tensor shape
95 /// - The indices buffer length matches the product of index shape dimensions
96 ///
97 /// # Panics
98 ///
99 /// This function will panic if:
100 /// - `dim` is greater than or equal to the tensor's rank
101 /// - Any index in `indices` is out of bounds for the specified dimension
102 /// - The `index_shape` rank doesn't match the input tensor's rank
103 /// - The `index_shape` dimensions don't match the input tensor (except along `dim`)
104 /// - The `indices` length doesn't equal the product of `index_shape` dimensions
105 #[track_caller]
106 pub fn gather(&self, dim: usize, indices: &[usize], index_shape: &[usize]) -> Tensor {
107 let rank = self.shape().rank();
108 assert!(
109 dim < rank,
110 "gather dim {} out of bounds for rank {}",
111 dim,
112 rank
113 );
114
115 // Validate index_shape compatibility: same rank and all dims equal except along dim
116 assert_eq!(
117 index_shape.len(),
118 rank,
119 "index_shape rank mismatch: {} vs {}",
120 index_shape.len(),
121 rank
122 );
123 for (i, &s) in index_shape.iter().enumerate().take(rank) {
124 if i != dim {
125 assert_eq!(
126 s,
127 self.shape().dims()[i],
128 "index_shape mismatch at dim {}: {} vs {}",
129 i,
130 s,
131 self.shape().dims()[i]
132 );
133 }
134 }
135
136 let index_numel: usize = index_shape.iter().product();
137 assert_eq!(
138 indices.len(),
139 index_numel,
140 "indices length {} must equal product of index_shape {}",
141 indices.len(),
142 index_numel
143 );
144
145 // Validate indices range along dim
146 let dim_size = self.shape().dims()[dim];
147 for &idx in indices.iter() {
148 assert!(
149 idx < dim_size,
150 "gather index {} out of bounds for dim {} (size {})",
151 idx,
152 dim,
153 dim_size
154 );
155 }
156
157 // Output shape equals index_shape
158 let mut output = Tensor::new(index_shape.to_vec());
159
160 // Precompute input strides for fast offset computation
161 let in_strides = self.strides().to_vec();
162
163 // Iterate over all positions in output/index tensor
164 let rank = index_shape.len();
165 let mut coords = vec![0usize; rank];
166 for (lin, &idx) in indices.iter().enumerate().take(index_numel) {
167 // Decode linear index to multi-dimensional coords
168 let mut tmp = lin;
169 for i in (0..rank).rev() {
170 let s = index_shape[i];
171 coords[i] = tmp % s;
172 tmp /= s;
173 }
174
175 let mut src_off = 0usize;
176 for i in 0..rank {
177 let c = if i == dim { idx } else { coords[i] };
178 src_off += c * in_strides[i];
179 }
180
181 unsafe {
182 *output.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
183 }
184 }
185
186 // GradTrack registration
187 if self.requires_grad() && is_grad_enabled() {
188 output.set_requires_grad(true);
189 let grad_fn = GradFn::Gather {
190 dim,
191 indices: indices.to_vec(),
192 input_shape: self.shape().dims().to_vec(),
193 index_shape: index_shape.to_vec(),
194 };
195 output.set_grad_fn(grad_fn.clone());
196 GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
197 }
198
199 output
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gather_basic() {
        // Input shape [2,3]: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
        let x = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
        let out = x.gather(1, &[2, 0, 1, 1], &[2, 2]);
        assert_eq!(out.shape().dims(), vec![2, 2]);
        // Row 0 gathers columns [2,0] -> [0.2, 0.0];
        // row 1 gathers columns [1,1] -> [0.4, 0.4].
        for (coords, want) in [
            ([0usize, 0usize], 0.2),
            ([0, 1], 0.0),
            ([1, 0], 0.4),
            ([1, 1], 0.4),
        ] {
            assert!((out.get(&coords) - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_gather_gradients_accumulate() {
        // Repeated indices along dim=1 must accumulate gradient contributions.
        let x = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let mut y = x.gather(1, &[1, 1, 0, 2], &[2, 2]);
        // The engine defaults the upstream gradient to ones.
        y.backward(None);
        let gx = x.grad_owned().expect("grad missing");
        assert_eq!(gx.shape().dims(), vec![2, 3]);
        // Row 0 gathers [1,1] -> input[0,1] receives +2;
        // row 1 gathers [0,2] -> input[1,0] and input[1,2] each receive +1.
        for (coords, want) in [
            ([0usize, 0usize], 0.0),
            ([0, 1], 2.0),
            ([0, 2], 0.0),
            ([1, 0], 1.0),
            ([1, 1], 0.0),
            ([1, 2], 1.0),
        ] {
            assert!((gx.get(&coords) - want).abs() < 1e-6);
        }
    }

    #[test]
    #[should_panic]
    fn test_gather_invalid_dim() {
        // dim == rank must panic.
        let x = Tensor::zeros(vec![2, 3]);
        let _ = x.gather(2, &[0, 0], &[2, 1]);
    }

    #[test]
    #[should_panic]
    fn test_gather_index_shape_mismatch() {
        // index_shape with the wrong rank must panic.
        let x = Tensor::zeros(vec![2, 3]);
        let _ = x.gather(1, &[0, 0], &[2]);
    }
}
259}