train_station/tensor/indexing/gather.rs

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Gather values along a dimension using a tensor of indices
    ///
    /// This operation extracts elements from the input tensor using indices provided
    /// along a specified dimension. The output tensor has the same shape as the index
    /// tensor: each output element is read from the input at the corresponding
    /// coordinates, with the index value substituted along the specified dimension.
    ///
    /// The gather operation is commonly used in machine learning for operations like
    /// embedding lookups, attention mechanisms, and advanced indexing patterns.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which to gather values (must be < tensor rank)
    /// * `indices` - Row-major flattened buffer of indices giving the positions to gather from
    /// * `index_shape` - Shape of the indices tensor and output tensor
    ///
    /// # Returns
    ///
    /// A new tensor with shape `index_shape` containing the gathered values
    ///
    /// # Examples
    ///
    /// ## Basic Gather Operation
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create a 2x3 tensor: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
    /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
    ///
    /// // Gather along dimension 1 (columns) with indices [2, 0, 1, 1]
    /// let indices = [2, 0, 1, 1];
    /// let index_shape = [2, 2];
    /// let result = tensor.gather(1, &indices, &index_shape);
    ///
    /// // Result shape is [2, 2]
    /// assert_eq!(result.shape().dims(), vec![2, 2]);
    ///
    /// // Row 0: indices [2, 0] -> [0.2, 0.0]
    /// assert!((result.get(&[0, 0]) - 0.2).abs() < 1e-6);
    /// assert!((result.get(&[0, 1]) - 0.0).abs() < 1e-6);
    ///
    /// // Row 1: indices [1, 1] -> [0.4, 0.4]
    /// assert!((result.get(&[1, 0]) - 0.4).abs() < 1e-6);
    /// assert!((result.get(&[1, 1]) - 0.4).abs() < 1e-6);
    /// ```
    ///
    /// ## Gather with Gradient Tracking
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
    ///     .with_requires_grad();
    ///
    /// let indices = [1, 1, 0, 2];
    /// let index_shape = [2, 2];
    /// let mut result = tensor.gather(1, &indices, &index_shape);
    ///
    /// // Compute gradients
    /// result.backward(None);
    /// let grad = tensor.grad_owned().expect("gradient missing");
    ///
    /// // Verify gradient accumulation for repeated indices
    /// assert!((grad.get(&[0, 1]) - 2.0).abs() < 1e-6); // Index 1 used twice in row 0
    /// ```
    ///
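    /// ## Embedding-Style Lookup
    ///
    /// A minimal sketch of the embedding-lookup pattern mentioned above, using only the
    /// `gather` call documented here (no dedicated embedding API is assumed): each row
    /// index is repeated across the embedding dimension so that whole rows are selected
    /// along dimension 0.
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // A 3x2 "embedding table"; each row is one embedding vector
    /// let table = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
    ///
    /// // Look up rows 2 and 0: repeat each row index across the 2 columns
    /// let indices = [2, 2, 0, 0];
    /// let rows = table.gather(0, &indices, &[2, 2]);
    ///
    /// assert_eq!(rows.shape().dims(), vec![2, 2]);
    /// assert!((rows.get(&[0, 0]) - 5.0).abs() < 1e-6);
    /// assert!((rows.get(&[0, 1]) - 6.0).abs() < 1e-6);
    /// assert!((rows.get(&[1, 0]) - 1.0).abs() < 1e-6);
    /// assert!((rows.get(&[1, 1]) - 2.0).abs() < 1e-6);
    /// ```
    ///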
    /// # Performance Characteristics
    ///
    /// - **Time Complexity**: O(n) where n is the number of elements in the output
    /// - **Memory Usage**: Creates a new tensor with the same size as the index tensor
    /// - **Optimization**: Uses precomputed strides for efficient memory access
    /// - **GradTrack Overhead**: Minimal when gradient tracking is enabled
    ///
    /// # Implementation Details
    ///
    /// The gather operation works by:
    /// 1. Validating input dimensions and index bounds
    /// 2. Creating an output tensor with the specified index shape
    /// 3. Iterating through all positions in the output tensor
    /// 4. Computing source offsets using the input tensor's strides
    /// 5. Copying values from the input tensor to the output tensor
    /// 6. Registering the operation for gradient computation if needed
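    ///
    /// As a worked sketch of step 4: for a `[2, 3]` input with row-major strides `[3, 1]`,
    /// an output position `[1, 0]` gathering index `2` along dimension 1 reads from source
    /// offset `1 * 3 + 2 * 1 = 5`, i.e. element `[1, 2]` of the input.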
    ///
    /// # Safety
    ///
    /// This function performs bounds checking to ensure:
    /// - The specified dimension is within the tensor's rank
    /// - All indices are within bounds for the specified dimension
    /// - The index shape is compatible with the input tensor shape
    /// - The indices buffer length matches the product of index shape dimensions
    ///
    /// # Panics
    ///
    /// This function will panic if:
    /// - `dim` is greater than or equal to the tensor's rank
    /// - Any index in `indices` is out of bounds for the specified dimension
    /// - The `index_shape` rank doesn't match the input tensor's rank
    /// - The `index_shape` dimensions don't match the input tensor (except along `dim`)
    /// - The `indices` length doesn't equal the product of `index_shape` dimensions
    #[track_caller]
    pub fn gather(&self, dim: usize, indices: &[usize], index_shape: &[usize]) -> Tensor {
        let rank = self.shape().rank();
        assert!(
            dim < rank,
            "gather dim {} out of bounds for rank {}",
            dim,
            rank
        );

        // Validate index_shape compatibility: same rank and all dims equal except along dim
        assert_eq!(
            index_shape.len(),
            rank,
            "index_shape rank mismatch: {} vs {}",
            index_shape.len(),
            rank
        );
        for (i, &s) in index_shape.iter().enumerate().take(rank) {
            if i != dim {
                assert_eq!(
                    s,
                    self.shape().dims()[i],
                    "index_shape mismatch at dim {}: {} vs {}",
                    i,
                    s,
                    self.shape().dims()[i]
                );
            }
        }

        let index_numel: usize = index_shape.iter().product();
        assert_eq!(
            indices.len(),
            index_numel,
            "indices length {} must equal product of index_shape {}",
            indices.len(),
            index_numel
        );

        // Validate indices range along dim
        let dim_size = self.shape().dims()[dim];
        for &idx in indices.iter() {
            assert!(
                idx < dim_size,
                "gather index {} out of bounds for dim {} (size {})",
                idx,
                dim,
                dim_size
            );
        }

        // Output shape equals index_shape
        let mut output = Tensor::new(index_shape.to_vec());

        // Precompute input strides for fast offset computation
        let in_strides = self.strides().to_vec();

        // Iterate over all positions in output/index tensor
        let rank = index_shape.len();
        let mut coords = vec![0usize; rank];
        for (lin, &idx) in indices.iter().enumerate().take(index_numel) {
            // Decode linear index to multi-dimensional coords
            let mut tmp = lin;
            for i in (0..rank).rev() {
                let s = index_shape[i];
                coords[i] = tmp % s;
                tmp /= s;
            }

            let mut src_off = 0usize;
            for i in 0..rank {
                let c = if i == dim { idx } else { coords[i] };
                src_off += c * in_strides[i];
            }
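            // SAFETY: `lin` is below `index_numel` (the output element count) and
            // `src_off` is built from the bounds-checked coordinates above, so both
            // pointer offsets stay within their respective buffers.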
            unsafe {
                *output.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
            }
        }

        // GradTrack registration
        if self.requires_grad() {
            output.set_requires_grad(true);
            let grad_fn = GradFn::Gather {
                dim,
                indices: indices.to_vec(),
                input_shape: self.shape().dims().to_vec(),
                index_shape: index_shape.to_vec(),
            };
            output.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
        }

        output
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gather_basic() {
        // x shape [2,3]: [[0.0, 0.1, 0.2],[0.3,0.4,0.5]]
        let x = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
        let out = x.gather(1, &[2, 0, 1, 1], &[2, 2]);
        assert_eq!(out.shape().dims(), vec![2, 2]);
        // Row 0 gathered indices [2,0] -> [0.2, 0.0]
        assert!((out.get(&[0, 0]) - 0.2).abs() < 1e-6);
        assert!((out.get(&[0, 1]) - 0.0).abs() < 1e-6);
        // Row 1 gathered indices [1,1] -> [0.4, 0.4]
        assert!((out.get(&[1, 0]) - 0.4).abs() < 1e-6);
        assert!((out.get(&[1, 1]) - 0.4).abs() < 1e-6);
    }

    #[test]
    fn test_gather_gradients_accumulate() {
        // x shape [2,3], gather along dim=1 with repeated indices to test accumulation
        let x = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let mut y = x.gather(1, &[1, 1, 0, 2], &[2, 2]);
        // Upstream gradient defaults to ones in our engine
        y.backward(None);
        let gx = x.grad_owned().expect("grad missing");
        // Expected grad counts per input element:
        // For row 0: indices [1,1] -> input[0,1] gets +2
        // For row 1: indices [0,2] -> input[1,0] gets +1, input[1,2] gets +1
        assert_eq!(gx.shape().dims(), vec![2, 3]);
        // Row 0
        assert!((gx.get(&[0, 0]) - 0.0).abs() < 1e-6);
        assert!((gx.get(&[0, 1]) - 2.0).abs() < 1e-6);
        assert!((gx.get(&[0, 2]) - 0.0).abs() < 1e-6);
        // Row 1
        assert!((gx.get(&[1, 0]) - 1.0).abs() < 1e-6);
        assert!((gx.get(&[1, 1]) - 0.0).abs() < 1e-6);
        assert!((gx.get(&[1, 2]) - 1.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic]
    fn test_gather_invalid_dim() {
        let x = Tensor::zeros(vec![2, 3]);
        let _ = x.gather(2, &[0, 0], &[2, 1]);
    }

    #[test]
    #[should_panic]
    fn test_gather_index_shape_mismatch() {
        let x = Tensor::zeros(vec![2, 3]);
        // index_shape rank mismatch
        let _ = x.gather(1, &[0, 0], &[2]);
    }
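
    // Sketch of an additional case: the Panics section documents out-of-range indices,
    // which the tests above do not cover. Uses only the APIs exercised above.
    #[test]
    #[should_panic]
    fn test_gather_index_out_of_bounds() {
        let x = Tensor::zeros(vec![2, 3]);
        // index_shape [2, 3] is valid, but index 3 is out of bounds for dim 1 (size 3)
        let _ = x.gather(1, &[3, 0, 0, 0, 0, 0], &[2, 3]);
    }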
}