// train_station/tensor/indexing/select.rs

1use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Select a slice along a given dimension at a specific index
6    ///
7    /// This operation extracts a slice from the input tensor by fixing a specific dimension
8    /// at a given index. The result is a tensor with one fewer dimension than the input,
9    /// containing the selected slice.
10    ///
11    /// The select operation returns a view (zero-copy) when the base offset is zero,
12    /// otherwise it creates a contiguous copy to ensure correctness. This operation is
13    /// commonly used for extracting specific rows, columns, or slices from tensors.
14    ///
15    /// # Arguments
16    ///
17    /// * `dim` - The dimension along which to select (must be < tensor rank)
18    /// * `index` - The index within the specified dimension to select (must be < dim size)
19    ///
20    /// # Returns
21    ///
22    /// A tensor with the selected slice. The result has the same shape as the input
23    /// except with the specified dimension removed.
24    ///
25    /// # Performance
26    ///
27    /// Returns a view when possible (base offset is zero) to avoid copying. On
28    /// non-zero offsets, falls back to a contiguous copy for correctness. Gradients
29    /// propagate back to the selected slice when GradTrack is enabled.
30    ///
31    /// # Examples
32    ///
33    /// ## Basic Row Selection
34    ///
35    /// ```
36    /// use train_station::Tensor;
37    ///
38    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
39    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
40    ///
41    /// // Select row 1 (dimension 0, index 1)
42    /// let result = tensor.select(0, 1);
43    ///
44    /// // Result shape is [3] (dimension 0 removed)
45    /// assert_eq!(result.shape().dims(), vec![3]);
46    /// assert_eq!(result.get(&[0]), 3.0);  // First element of row 1
47    /// assert_eq!(result.get(&[1]), 4.0);  // Second element of row 1
48    /// assert_eq!(result.get(&[2]), 5.0);  // Third element of row 1
49    /// ```
50    ///
51    /// ## Column Selection
52    ///
53    /// ```
54    /// use train_station::Tensor;
55    ///
56    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
57    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
58    ///
59    /// // Select column 1 (dimension 1, index 1)
60    /// let result = tensor.select(1, 1);
61    ///
62    /// // Result shape is [2] (dimension 1 removed)
63    /// assert_eq!(result.shape().dims(), vec![2]);
64    /// assert_eq!(result.get(&[0]), 1.0);  // Column 1, row 0
65    /// assert_eq!(result.get(&[1]), 4.0);  // Column 1, row 1
66    /// ```
67    ///
68    /// ## Select with Gradient Tracking
69    ///
70    /// ```
71    /// use train_station::Tensor;
72    ///
73    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap()
74    ///     .with_requires_grad();
75    ///
76    /// // Select row 1 with gradient tracking enabled
77    /// let mut result = tensor.select(0, 1);
78    /// result.backward(None);
79    ///
80    /// // Verify gradients are computed correctly
81    /// let grad = tensor.grad_owned().expect("gradient missing");
82    /// assert_eq!(grad.shape().dims(), vec![2, 2]);
83    /// // Only row 1 receives gradients
84    /// assert_eq!(grad.get(&[0, 0]), 0.0);  // Row 0: no gradient
85    /// assert_eq!(grad.get(&[0, 1]), 0.0);  // Row 0: no gradient
86    /// assert_eq!(grad.get(&[1, 0]), 1.0);  // Row 1: gradient flows
87    /// assert_eq!(grad.get(&[1, 1]), 1.0);  // Row 1: gradient flows
88    /// ```
89    ///
90    /// # Performance Characteristics
91    ///
92    /// - **Time Complexity**: O(n) where n is the number of elements in the selected slice
93    /// - **Memory Usage**: Zero-copy view when base offset is zero, otherwise creates a copy
94    /// - **Optimization**: Uses efficient stride-based access for non-contiguous tensors
95    /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
96    /// - **Memory Layout**: Result is contiguous when a copy is made, view otherwise
97    ///
98    /// # Implementation Details
99    ///
100    /// The select operation works by:
101    /// 1. Validating the dimension and index bounds
102    /// 2. Computing the new shape by removing the selected dimension
103    /// 3. Computing the new strides by removing the selected dimension's stride
104    /// 4. Calculating the base offset for the selected slice
105    /// 5. If base offset is zero: creating a view with adjusted shape/strides
106    /// 6. If base offset is non-zero: creating a contiguous copy of the slice
107    /// 7. Registering the operation for gradient computation if needed
108    ///
109    /// # Safety
110    ///
111    /// This function performs comprehensive bounds checking to ensure:
112    /// - The tensor has non-zero rank
113    /// - The specified dimension is within the tensor's rank
114    /// - The index is within bounds for the specified dimension
115    /// - Memory access is safe through proper offset calculations
116    ///
117    /// # Panics
118    ///
119    /// This function will panic if:
120    /// - The tensor has zero rank
121    /// - `dim` is greater than or equal to the tensor's rank
122    /// - `index` is greater than or equal to the size of the specified dimension
123    ///
124    /// # Thread Safety
125    ///
126    /// This function is thread-safe and can be called concurrently on different tensors.
127    /// The operation does not modify the input tensor and creates either a view or a new tensor.
128    ///
129    /// # View vs Copy Behavior
130    ///
131    /// - **View (zero-copy)**: When the base offset is zero, returns a view that shares
132    ///   the same memory as the input tensor with adjusted shape and strides
133    /// - **Copy**: When the base offset is non-zero, creates a contiguous copy to ensure
134    ///   correctness across all operations
135    ///
136    /// # GradTrack Behavior
137    ///
138    /// When gradient tracking is enabled:
139    /// - Gradients are scattered back to the selected slice in the input tensor
140    /// - Other positions in the input tensor receive zero gradients
141    /// - This behavior ensures correct gradient flow for the selected elements
142    #[track_caller]
143    pub fn select(&self, dim: usize, index: usize) -> Tensor {
144        let rank = self.shape().rank();
145        assert!(rank > 0, "select requires non-zero rank");
146        assert!(
147            dim < rank,
148            "select dim {} out of bounds for rank {}",
149            dim,
150            rank
151        );
152        let dim_size = self.shape().dims()[dim];
153        assert!(
154            index < dim_size,
155            "select index {} out of bounds for dimension {} (size {})",
156            index,
157            dim,
158            dim_size
159        );
160
161        // Build new dims/strides removing the selected dimension
162        let mut new_dims = Vec::with_capacity(rank - 1);
163        let mut new_strides = Vec::with_capacity(rank - 1);
164        for i in 0..rank {
165            if i == dim {
166                continue;
167            }
168            new_dims.push(self.shape().dims()[i]);
169            new_strides.push(self.strides()[i]);
170        }
171
172        // Base pointer shift by index * stride(dim)
173        let base_offset = index * self.stride(dim);
174
175        // Create zero-copy view using core as_strided with storage_offset
176        let mut result = match crate::tensor::core::view::as_strided_view(
177            self,
178            &new_dims,
179            &new_strides,
180            base_offset,
181        ) {
182            Ok(v) => v,
183            Err(e) => panic!("select view error: {:?}", e),
184        };
185
186        // GradTrack registration only when gradients are enabled and input requires grad
187        if self.requires_grad() && is_grad_enabled() {
188            result.set_requires_grad(true);
189            let grad_fn = GradFn::Select {
190                dim,
191                index,
192                input_shape: self.shape().dims().to_vec(),
193            };
194            result.set_grad_fn(grad_fn.clone());
195            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
196        }
197
198        result
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting a row of a 2x3 tensor removes dimension 0 and yields
    /// the elements of that row.
    #[test]
    fn test_select_basic() {
        let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let row = tensor.select(0, 1);
        assert_eq!(row.shape().dims(), vec![3]);
        assert_eq!(row.get(&[0]), 3.0);
        assert_eq!(row.get(&[2]), 5.0);
    }

    /// Backward through select scatters ones into the selected row and
    /// leaves zeros everywhere else in the input gradient.
    #[test]
    fn test_select_grad() {
        let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut selected = tensor.select(0, 1);
        selected.backward(None);
        let grad = tensor.grad_owned().expect("grad missing");
        // Row 0 receives no gradient; row 1 receives ones.
        let expected = [([0, 0], 0.0), ([0, 1], 0.0), ([1, 0], 1.0), ([1, 1], 1.0)];
        for (coords, value) in expected {
            assert_eq!(grad.get(&coords), value);
        }
    }
}
229}