//! Gather operation for tensor indexing.
//!
//! Path: train_station/tensor/indexing/gather.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Gather values along a dimension using a tensor of indices
6    ///
7    /// This operation extracts elements from the input tensor based on indices provided
8    /// along a specified dimension. The output tensor has the same shape as the index
9    /// tensor, with each element taken from the input tensor at the corresponding
10    /// position with the index value substituted for the specified dimension.
11    ///
12    /// The gather operation is commonly used in machine learning for operations like
13    /// embedding lookups, attention mechanisms, and advanced indexing patterns.
14    ///
15    /// # Arguments
16    ///
17    /// * `dim` - The dimension along which to gather values (must be < tensor rank)
18    /// * `indices` - Flattened indices buffer containing the positions to gather from
19    /// * `index_shape` - Shape of the indices tensor and output tensor
20    ///
21    /// # Returns
22    ///
23    /// A new tensor with shape `index_shape` containing the gathered values
24    ///
25    /// # Examples
26    ///
27    /// ## Basic Gather Operation
28    ///
29    /// ```
30    /// use train_station::Tensor;
31    ///
32    /// // Create a 2x3 tensor: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
33    /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
34    ///
35    /// // Gather along dimension 1 (columns) with indices [2, 0, 1, 1]
36    /// let indices = [2, 0, 1, 1];
37    /// let index_shape = [2, 2];
38    /// let result = tensor.gather(1, &indices, &index_shape);
39    ///
40    /// // Result shape is [2, 2]
41    /// assert_eq!(result.shape().dims, vec![2, 2]);
42    ///
43    /// // Row 0: indices [2, 0] -> [0.2, 0.0]
44    /// assert!((result.get(&[0, 0]) - 0.2).abs() < 1e-6);
45    /// assert!((result.get(&[0, 1]) - 0.0).abs() < 1e-6);
46    ///
47    /// // Row 1: indices [1, 1] -> [0.4, 0.4]
48    /// assert!((result.get(&[1, 0]) - 0.4).abs() < 1e-6);
49    /// assert!((result.get(&[1, 1]) - 0.4).abs() < 1e-6);
50    /// ```
51    ///
52    /// ## Gather with Gradient Tracking
53    ///
54    /// ```
55    /// use train_station::Tensor;
56    ///
57    /// let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
58    ///     .with_requires_grad();
59    ///
60    /// let indices = [1, 1, 0, 2];
61    /// let index_shape = [2, 2];
62    /// let mut result = tensor.gather(1, &indices, &index_shape);
63    ///
64    /// // Compute gradients
65    /// result.backward(None);
66    /// let grad = tensor.grad_by_value().expect("gradient missing");
67    ///
68    /// // Verify gradient accumulation for repeated indices
69    /// assert!((grad.get(&[0, 1]) - 2.0).abs() < 1e-6); // Index 1 used twice in row 0
70    /// ```
71    ///
72    /// # Performance Characteristics
73    ///
74    /// - **Time Complexity**: O(n) where n is the number of elements in the output
75    /// - **Memory Usage**: Creates a new tensor with the same size as the index tensor
76    /// - **Optimization**: Uses precomputed strides for efficient memory access
77    /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
78    ///
79    /// # Implementation Details
80    ///
81    /// The gather operation works by:
82    /// 1. Validating input dimensions and index bounds
83    /// 2. Creating an output tensor with the specified index shape
84    /// 3. Iterating through all positions in the output tensor
85    /// 4. Computing source offsets using the input tensor's strides
86    /// 5. Copying values from the input tensor to the output tensor
87    /// 6. Registering the operation for gradient computation if needed
88    ///
89    /// # Safety
90    ///
91    /// This function performs bounds checking to ensure:
92    /// - The specified dimension is within the tensor's rank
93    /// - All indices are within bounds for the specified dimension
94    /// - The index shape is compatible with the input tensor shape
95    /// - The indices buffer length matches the product of index shape dimensions
96    ///
97    /// # Panics
98    ///
99    /// This function will panic if:
100    /// - `dim` is greater than or equal to the tensor's rank
101    /// - Any index in `indices` is out of bounds for the specified dimension
102    /// - The `index_shape` rank doesn't match the input tensor's rank
103    /// - The `index_shape` dimensions don't match the input tensor (except along `dim`)
104    /// - The `indices` length doesn't equal the product of `index_shape` dimensions
105    pub fn gather(&self, dim: usize, indices: &[usize], index_shape: &[usize]) -> Tensor {
106        let rank = self.shape().rank();
107        assert!(
108            dim < rank,
109            "gather dim {} out of bounds for rank {}",
110            dim,
111            rank
112        );
113
114        // Validate index_shape compatibility: same rank and all dims equal except along dim
115        assert_eq!(
116            index_shape.len(),
117            rank,
118            "index_shape rank mismatch: {} vs {}",
119            index_shape.len(),
120            rank
121        );
122        for (i, &s) in index_shape.iter().enumerate().take(rank) {
123            if i != dim {
124                assert_eq!(
125                    s,
126                    self.shape().dims[i],
127                    "index_shape mismatch at dim {}: {} vs {}",
128                    i,
129                    s,
130                    self.shape().dims[i]
131                );
132            }
133        }
134
135        let index_numel: usize = index_shape.iter().product();
136        assert_eq!(
137            indices.len(),
138            index_numel,
139            "indices length {} must equal product of index_shape {}",
140            indices.len(),
141            index_numel
142        );
143
144        // Validate indices range along dim
145        let dim_size = self.shape().dims[dim];
146        for &idx in indices.iter() {
147            assert!(
148                idx < dim_size,
149                "gather index {} out of bounds for dim {} (size {})",
150                idx,
151                dim,
152                dim_size
153            );
154        }
155
156        // Output shape equals index_shape
157        let mut output = Tensor::new(index_shape.to_vec());
158
159        // Precompute input strides for fast offset computation
160        let in_strides = self.strides().to_vec();
161
162        // Iterate over all positions in output/index tensor
163        let rank = index_shape.len();
164        let mut coords = vec![0usize; rank];
165        for (lin, &idx) in indices.iter().enumerate().take(index_numel) {
166            // Decode linear index to multi-dimensional coords
167            let mut tmp = lin;
168            for i in (0..rank).rev() {
169                let s = index_shape[i];
170                coords[i] = tmp % s;
171                tmp /= s;
172            }
173
174            let mut src_off = 0usize;
175            for i in 0..rank {
176                let c = if i == dim { idx } else { coords[i] };
177                src_off += c * in_strides[i];
178            }
179
180            unsafe {
181                *output.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
182            }
183        }
184
185        // GradTrack registration
186        if self.requires_grad() {
187            output.set_requires_grad(true);
188            let grad_fn = GradFn::Gather {
189                dim,
190                indices: indices.to_vec(),
191                input_shape: self.shape().dims.clone(),
192                index_shape: index_shape.to_vec(),
193            };
194            output.set_grad_fn(grad_fn.clone());
195            GradEngine::register_operation(output.id(), vec![self.id()], grad_fn);
196        }
197
198        output
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gather_basic() {
        // Input of shape [2, 3]: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
        let input = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
        let gathered = input.gather(1, &[2, 0, 1, 1], &[2, 2]);
        assert_eq!(gathered.shape().dims, vec![2, 2]);
        // Row 0 picks columns [2, 0] -> [0.2, 0.0];
        // row 1 picks columns [1, 1] -> [0.4, 0.4].
        let expected = [
            ([0usize, 0usize], 0.2),
            ([0, 1], 0.0),
            ([1, 0], 0.4),
            ([1, 1], 0.4),
        ];
        for &(pos, want) in expected.iter() {
            assert!((gathered.get(&pos) - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_gather_gradients_accumulate() {
        // Gather along dim=1 with a repeated index to exercise gradient
        // accumulation at the duplicated source position.
        let input = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3])
            .unwrap()
            .with_requires_grad();
        let mut out = input.gather(1, &[1, 1, 0, 2], &[2, 2]);
        // Upstream gradient defaults to ones in our engine.
        out.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        assert_eq!(grad.shape().dims, vec![2, 3]);
        // Row 0 gathered [1, 1] -> input[0,1] accumulates 2;
        // row 1 gathered [0, 2] -> input[1,0] and input[1,2] each receive 1.
        let expected = [
            ([0usize, 0usize], 0.0),
            ([0, 1], 2.0),
            ([0, 2], 0.0),
            ([1, 0], 1.0),
            ([1, 1], 0.0),
            ([1, 2], 1.0),
        ];
        for &(pos, want) in expected.iter() {
            assert!((grad.get(&pos) - want).abs() < 1e-6);
        }
    }

    #[test]
    #[should_panic]
    fn test_gather_invalid_dim() {
        // dim equal to the rank is out of bounds and must panic.
        let t = Tensor::zeros(vec![2, 3]);
        let _ = t.gather(2, &[0, 0], &[2, 1]);
    }

    #[test]
    #[should_panic]
    fn test_gather_index_shape_mismatch() {
        // index_shape with the wrong rank must panic.
        let t = Tensor::zeros(vec![2, 3]);
        let _ = t.gather(1, &[0, 0], &[2]);
    }
}