train_station/tensor/indexing/select.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Select a slice along a given dimension at a specific index
6    ///
7    /// This operation extracts a slice from the input tensor by fixing a specific dimension
8    /// at a given index. The result is a tensor with one fewer dimension than the input,
9    /// containing the selected slice.
10    ///
11    /// The select operation returns a view (zero-copy) when the base offset is zero,
12    /// otherwise it creates a contiguous copy to ensure correctness. This operation is
13    /// commonly used for extracting specific rows, columns, or slices from tensors.
14    ///
15    /// # Arguments
16    ///
17    /// * `dim` - The dimension along which to select (must be < tensor rank)
18    /// * `index` - The index within the specified dimension to select (must be < dim size)
19    ///
20    /// # Returns
21    ///
22    /// A tensor with the selected slice. The result has the same shape as the input
23    /// except with the specified dimension removed.
24    ///
25    /// # Examples
26    ///
27    /// ## Basic Row Selection
28    ///
29    /// ```
30    /// use train_station::Tensor;
31    ///
32    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
33    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
34    ///
35    /// // Select row 1 (dimension 0, index 1)
36    /// let result = tensor.select(0, 1);
37    ///
38    /// // Result shape is [3] (dimension 0 removed)
39    /// assert_eq!(result.shape().dims, vec![3]);
40    /// assert_eq!(result.get(&[0]), 3.0);  // First element of row 1
41    /// assert_eq!(result.get(&[1]), 4.0);  // Second element of row 1
42    /// assert_eq!(result.get(&[2]), 5.0);  // Third element of row 1
43    /// ```
44    ///
45    /// ## Column Selection
46    ///
47    /// ```
48    /// use train_station::Tensor;
49    ///
50    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
51    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
52    ///
53    /// // Select column 1 (dimension 1, index 1)
54    /// let result = tensor.select(1, 1);
55    ///
56    /// // Result shape is [2] (dimension 1 removed)
57    /// assert_eq!(result.shape().dims, vec![2]);
58    /// assert_eq!(result.get(&[0]), 1.0);  // Column 1, row 0
59    /// assert_eq!(result.get(&[1]), 4.0);  // Column 1, row 1
60    /// ```
61    ///
62    /// ## Select with Gradient Tracking
63    ///
64    /// ```
65    /// use train_station::Tensor;
66    ///
67    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap()
68    ///     .with_requires_grad();
69    ///
70    /// // Select row 1 with gradient tracking enabled
71    /// let mut result = tensor.select(0, 1);
72    /// result.backward(None);
73    ///
74    /// // Verify gradients are computed correctly
75    /// let grad = tensor.grad_by_value().expect("gradient missing");
76    /// assert_eq!(grad.shape().dims, vec![2, 2]);
77    /// // Only row 1 receives gradients
78    /// assert_eq!(grad.get(&[0, 0]), 0.0);  // Row 0: no gradient
79    /// assert_eq!(grad.get(&[0, 1]), 0.0);  // Row 0: no gradient
80    /// assert_eq!(grad.get(&[1, 0]), 1.0);  // Row 1: gradient flows
81    /// assert_eq!(grad.get(&[1, 1]), 1.0);  // Row 1: gradient flows
82    /// ```
83    ///
84    /// # Performance Characteristics
85    ///
86    /// - **Time Complexity**: O(n) where n is the number of elements in the selected slice
87    /// - **Memory Usage**: Zero-copy view when base offset is zero, otherwise creates a copy
88    /// - **Optimization**: Uses efficient stride-based access for non-contiguous tensors
89    /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
90    /// - **Memory Layout**: Result is contiguous when a copy is made, view otherwise
91    ///
92    /// # Implementation Details
93    ///
94    /// The select operation works by:
95    /// 1. Validating the dimension and index bounds
96    /// 2. Computing the new shape by removing the selected dimension
97    /// 3. Computing the new strides by removing the selected dimension's stride
98    /// 4. Calculating the base offset for the selected slice
99    /// 5. If base offset is zero: creating a view with adjusted shape/strides
100    /// 6. If base offset is non-zero: creating a contiguous copy of the slice
101    /// 7. Registering the operation for gradient computation if needed
102    ///
103    /// # Safety
104    ///
105    /// This function performs comprehensive bounds checking to ensure:
106    /// - The tensor has non-zero rank
107    /// - The specified dimension is within the tensor's rank
108    /// - The index is within bounds for the specified dimension
109    /// - Memory access is safe through proper offset calculations
110    ///
111    /// # Panics
112    ///
113    /// This function will panic if:
114    /// - The tensor has zero rank
115    /// - `dim` is greater than or equal to the tensor's rank
116    /// - `index` is greater than or equal to the size of the specified dimension
117    ///
118    /// # Thread Safety
119    ///
120    /// This function is thread-safe and can be called concurrently on different tensors.
121    /// The operation does not modify the input tensor and creates either a view or a new tensor.
122    ///
123    /// # View vs Copy Behavior
124    ///
125    /// - **View (zero-copy)**: When the base offset is zero, returns a view that shares
126    ///   the same memory as the input tensor with adjusted shape and strides
127    /// - **Copy**: When the base offset is non-zero, creates a contiguous copy to ensure
128    ///   correctness across all operations
129    ///
130    /// # GradTrack Behavior
131    ///
132    /// When gradient tracking is enabled:
133    /// - Gradients are scattered back to the selected slice in the input tensor
134    /// - Other positions in the input tensor receive zero gradients
135    /// - This behavior ensures correct gradient flow for the selected elements
136    pub fn select(&self, dim: usize, index: usize) -> Tensor {
137        let rank = self.shape().rank();
138        assert!(rank > 0, "select requires non-zero rank");
139        assert!(
140            dim < rank,
141            "select dim {} out of bounds for rank {}",
142            dim,
143            rank
144        );
145        let dim_size = self.shape().dims[dim];
146        assert!(
147            index < dim_size,
148            "select index {} out of bounds for dimension {} (size {})",
149            index,
150            dim,
151            dim_size
152        );
153
154        // Build new dims/strides removing the selected dimension
155        let mut new_dims = Vec::with_capacity(rank - 1);
156        let mut new_strides = Vec::with_capacity(rank - 1);
157        for i in 0..rank {
158            if i == dim {
159                continue;
160            }
161            new_dims.push(self.shape().dims[i]);
162            new_strides.push(self.strides()[i]);
163        }
164
165        // Base pointer shift by index * stride(dim)
166        let base_offset = index * self.stride(dim);
167
168        // Create a view with the same data pointer offset by base_offset
169        // We simulate pointer offset by using memory_offset in access; to preserve zero-copy
170        // semantics, we create a view over the same allocation and adjust shape/strides.
171        let view_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
172        let mut result = self.create_view_with_shape(view_shape);
173
174        // To account for base offset, rebase the `result` by materializing a small view window
175        // via contiguous() if base_offset != 0 for non-view correctness. For simplicity and
176        // correctness across all ops, create a contiguous copy if the base offset is non-zero.
177        if base_offset != 0 {
178            // Materialize contiguous slice
179            let mut contiguous = Tensor::new(result.shape().dims.clone());
180            // Copy elements from self using stride-aware reads starting at base_offset
181            let numel = contiguous.size();
182            let rank2 = result.shape().rank();
183            let mut coords = vec![0usize; rank2];
184            for lin in 0..numel {
185                // Decode coords in result space
186                let mut tmp = lin;
187                for i in (0..rank2).rev() {
188                    let s = result.shape().dims[i];
189                    coords[i] = if s == 0 { 0 } else { tmp % s };
190                    if s != 0 {
191                        tmp /= s;
192                    }
193                }
194                // Map to source coords inserting fixed index at dim
195                let mut src_coords = Vec::with_capacity(rank);
196                for i in 0..rank {
197                    if i == dim {
198                        src_coords.push(index);
199                    } else {
200                        let j = if i < dim { i } else { i - 1 };
201                        src_coords.push(coords[j]);
202                    }
203                }
204                let src_off = self.shape().offset(&src_coords);
205                unsafe {
206                    *contiguous.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
207                }
208            }
209            result = contiguous;
210        }
211
212        // GradTrack registration: backward scatters grad_output into zeros at the selected slice
213        if self.requires_grad() {
214            result.set_requires_grad(true);
215            let grad_fn = GradFn::Select {
216                dim,
217                index,
218                input_shape: self.shape().dims.clone(),
219            };
220            result.set_grad_fn(grad_fn.clone());
221            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
222        }
223
224        result
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting along dimension 0 drops that dimension and exposes the
    /// chosen row's elements in order.
    #[test]
    fn test_select_basic() {
        let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let row = tensor.select(0, 1);
        assert_eq!(row.shape().dims, vec![3]);
        assert_eq!(row.get(&[0]), 3.0);
        assert_eq!(row.get(&[2]), 5.0);
    }

    /// Backward through select scatters unit gradients into the selected
    /// row of the input and leaves every other position at zero.
    #[test]
    fn test_select_grad() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut selected = input.select(0, 1);
        selected.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        // Row 0 was not selected, so it receives no gradient; row 1 gets ones.
        assert_eq!(grad.get(&[0, 0]), 0.0);
        assert_eq!(grad.get(&[0, 1]), 0.0);
        assert_eq!(grad.get(&[1, 0]), 1.0);
        assert_eq!(grad.get(&[1, 1]), 1.0);
    }
}