//! Gather operation along a tensor dimension (train_station/tensor/indexing/gather.rs).
use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

4impl Tensor {
5 /// Gather values along a dimension using a tensor of indices
6 ///
7 /// This operation extracts elements from the input tensor based on indices provided
8 /// along a specified dimension. The output tensor has the same shape as the index
9 /// tensor, with each element taken from the input tensor at the corresponding
10 /// position with the index value substituted for the specified dimension.
11 ///
12 /// The gather operation is commonly used in machine learning for operations like
13 /// embedding lookups, attention mechanisms, and advanced indexing patterns.
14 ///
15 /// # Arguments
16 ///
17 /// * `dim` - The dimension along which to gather values (must be < tensor rank)
18 /// * `indices` - Flattened indices buffer containing the positions to gather from
19 /// * `index_shape` - Shape of the indices tensor and output tensor
20 ///
21 /// # Returns
22 ///
23 /// A new tensor with shape `index_shape` containing the gathered values
24 ///
25 /// # Examples
26 ///
27 /// ## Basic Gather Operation
28 ///
29 /// ```
30 /// use train_station::Tensor;
31 ///
32 /// // Create a 2x3 tensor: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
33 /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
34 ///
35 /// // Gather along dimension 1 (columns) with indices [2, 0, 1, 1]
36 /// let indices = [2, 0, 1, 1];
37 /// let index_shape = [2, 2];
38 /// let result = tensor.gather(1, &indices, &index_shape);
39 ///
40 /// // Result shape is [2, 2]
41 /// assert_eq!(result.shape().dims, vec![2, 2]);
42 ///
43 /// // Row 0: indices [2, 0] -> [0.2, 0.0]
44 /// assert!((result.get(&[0, 0]) - 0.2).abs() < 1e-6);
45 /// assert!((result.get(&[0, 1]) - 0.0).abs() < 1e-6);
46 ///
47 /// // Row 1: indices [1, 1] -> [0.4, 0.4]
48 /// assert!((result.get(&[1, 0]) - 0.4).abs() < 1e-6);
49 /// assert!((result.get(&[1, 1]) - 0.4).abs() < 1e-6);
50 /// ```
51 ///
52 /// ## Gather with Gradient Tracking
53 ///
54 /// ```
55 /// use train_station::Tensor;
56 ///
57 /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
58 /// .with_requires_grad();
59 ///
60 /// let indices = [1, 1, 0, 2];
61 /// let index_shape = [2, 2];
62 /// let mut result = tensor.gather(1, &indices, &index_shape);
63 ///
64 /// // Compute gradients
65 /// result.backward(None);
66 /// let grad = tensor.grad_by_value().expect("gradient missing");
67 ///
68 /// // Verify gradient accumulation for repeated indices
69 /// assert!((grad.get(&[0, 1]) - 2.0).abs() < 1e-6); // Index 1 used twice in row 0
70 /// ```
71 ///
72 /// # Performance Characteristics
73 ///
74 /// - **Time Complexity**: O(n) where n is the number of elements in the output
75 /// - **Memory Usage**: Creates a new tensor with the same size as the index tensor
76 /// - **Optimization**: Uses precomputed strides for efficient memory access
77 /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
78 ///
79 /// # Implementation Details
80 ///
81 /// The gather operation works by:
82 /// 1. Validating input dimensions and index bounds
83 /// 2. Creating an output tensor with the specified index shape
84 /// 3. Iterating through all positions in the output tensor
85 /// 4. Computing source offsets using the input tensor's strides
86 /// 5. Copying values from the input tensor to the output tensor
87 /// 6. Registering the operation for gradient computation if needed
88 ///
89 /// # Safety
90 ///
91 /// This function performs bounds checking to ensure:
92 /// - The specified dimension is within the tensor's rank
93 /// - All indices are within bounds for the specified dimension
94 /// - The index shape is compatible with the input tensor shape
95 /// - The indices buffer length matches the product of index shape dimensions
96 ///
97 /// # Panics
98 ///
99 /// This function will panic if:
100 /// - `dim` is greater than or equal to the tensor's rank
101 /// - Any index in `indices` is out of bounds for the specified dimension
102 /// - The `index_shape` rank doesn't match the input tensor's rank
103 /// - The `index_shape` dimensions don't match the input tensor (except along `dim`)
104 /// - The `indices` length doesn't equal the product of `index_shape` dimensions
105 pub fn gather(&self, dim: usize, indices: &[usize], index_shape: &[usize]) -> Tensor {
106 let rank = self.shape().rank();
107 assert!(
108 dim < rank,
109 "gather dim {} out of bounds for rank {}",
110 dim,
111 rank
112 );
113
114 // Validate index_shape compatibility: same rank and all dims equal except along dim
115 assert_eq!(
116 index_shape.len(),
117 rank,
118 "index_shape rank mismatch: {} vs {}",
119 index_shape.len(),
120 rank
121 );
122 for (i, &s) in index_shape.iter().enumerate().take(rank) {
123 if i != dim {
124 assert_eq!(
125 s,
126 self.shape().dims[i],
127 "index_shape mismatch at dim {}: {} vs {}",
128 i,
129 s,
130 self.shape().dims[i]
131 );
132 }
133 }
134
135 let index_numel: usize = index_shape.iter().product();
136 assert_eq!(
137 indices.len(),
138 index_numel,
139 "indices length {} must equal product of index_shape {}",
140 indices.len(),
141 index_numel
142 );
143
144 // Validate indices range along dim
145 let dim_size = self.shape().dims[dim];
146 for &idx in indices.iter() {
147 assert!(
148 idx < dim_size,
149 "gather index {} out of bounds for dim {} (size {})",
150 idx,
151 dim,
152 dim_size
153 );
154 }
155
156 // Output shape equals index_shape
157 let mut output = Tensor::new(index_shape.to_vec());
158
159 // Precompute input strides for fast offset computation
160 let in_strides = self.strides().to_vec();
161
162 // Iterate over all positions in output/index tensor
163 let rank = index_shape.len();
164 let mut coords = vec![0usize; rank];
165 for (lin, &idx) in indices.iter().enumerate().take(index_numel) {
166 // Decode linear index to multi-dimensional coords
167 let mut tmp = lin;
168 for i in (0..rank).rev() {
169 let s = index_shape[i];
170 coords[i] = tmp % s;
171 tmp /= s;
172 }
173
174 let mut src_off = 0usize;
175 for i in 0..rank {
176 let c = if i == dim { idx } else { coords[i] };
177 src_off += c * in_strides[i];
178 }
179
180 unsafe {
181 *output.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
182 }
183 }
184
185 // GradTrack registration
186 if self.requires_grad() {
187 output.set_requires_grad(true);
188 let grad_fn = GradFn::Gather {
189 dim,
190 indices: indices.to_vec(),
191 input_shape: self.shape().dims.clone(),
192 index_shape: index_shape.to_vec(),
193 };
194 output.set_grad_fn(grad_fn.clone());
195 GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
196 }
197
198 output
199 }
200}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gather_basic() {
        // Input [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]] gathered along columns (dim 1).
        let input = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
        let gathered = input.gather(1, &[2, 0, 1, 1], &[2, 2]);

        assert_eq!(gathered.shape().dims, vec![2, 2]);

        // Row 0 picks columns [2, 0] -> [0.2, 0.0]; row 1 picks [1, 1] -> [0.4, 0.4].
        let expected = [
            ([0usize, 0usize], 0.2f32),
            ([0, 1], 0.0),
            ([1, 0], 0.4),
            ([1, 1], 0.4),
        ];
        for (pos, want) in expected.iter() {
            assert!((gathered.get(pos) - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_gather_gradients_accumulate() {
        // Gather along dim 1 with a repeated index so gradient accumulation is exercised.
        let input = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let mut gathered = input.gather(1, &[1, 1, 0, 2], &[2, 2]);

        // Upstream gradient defaults to ones in our engine.
        gathered.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        assert_eq!(grad.shape().dims, vec![2, 3]);

        // Expected counts: input[0,1] is hit twice (+2); input[1,0] and input[1,2]
        // once each (+1); everything else stays at zero.
        let expected = [
            ([0usize, 0usize], 0.0f32),
            ([0, 1], 2.0),
            ([0, 2], 0.0),
            ([1, 0], 1.0),
            ([1, 1], 0.0),
            ([1, 2], 1.0),
        ];
        for (pos, want) in expected.iter() {
            assert!((grad.get(pos) - want).abs() < 1e-6);
        }
    }

    #[test]
    #[should_panic]
    fn test_gather_invalid_dim() {
        // dim == rank must be rejected.
        let input = Tensor::zeros(vec![2, 3]);
        let _ = input.gather(2, &[0, 0], &[2, 1]);
    }

    #[test]
    #[should_panic]
    fn test_gather_index_shape_mismatch() {
        // index_shape with a different rank than the input must be rejected.
        let input = Tensor::zeros(vec![2, 3]);
        let _ = input.gather(1, &[0, 0], &[2]);
    }
}
258}