// train_station/tensor/indexing/select.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
impl Tensor {
    /// Select a slice along a given dimension at a specific index
    ///
    /// This operation extracts a slice from the input tensor by fixing a specific dimension
    /// at a given index. The result is a tensor with one fewer dimension than the input,
    /// containing the selected slice.
    ///
    /// The select operation always returns a zero-copy view: the result shares the input
    /// tensor's memory and differs only in shape, strides, and storage offset. This
    /// operation is commonly used for extracting specific rows, columns, or slices from
    /// tensors.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which to select (must be < tensor rank)
    /// * `index` - The index within the specified dimension to select (must be < dim size)
    ///
    /// # Returns
    ///
    /// A tensor view of the selected slice. The result has the same shape as the input
    /// except with the specified dimension removed.
    ///
    /// # Examples
    ///
    /// ## Basic Row Selection
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
    ///
    /// // Select row 1 (dimension 0, index 1)
    /// let result = tensor.select(0, 1);
    ///
    /// // Result shape is [3] (dimension 0 removed)
    /// assert_eq!(result.shape().dims(), vec![3]);
    /// assert_eq!(result.get(&[0]), 3.0);  // First element of row 1
    /// assert_eq!(result.get(&[1]), 4.0);  // Second element of row 1
    /// assert_eq!(result.get(&[2]), 5.0);  // Third element of row 1
    /// ```
    ///
    /// ## Column Selection
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
    ///
    /// // Select column 1 (dimension 1, index 1)
    /// let result = tensor.select(1, 1);
    ///
    /// // Result shape is [2] (dimension 1 removed)
    /// assert_eq!(result.shape().dims(), vec![2]);
    /// assert_eq!(result.get(&[0]), 1.0);  // Column 1, row 0
    /// assert_eq!(result.get(&[1]), 4.0);  // Column 1, row 1
    /// ```
    ///
    /// ## Select with Gradient Tracking
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap()
    ///     .with_requires_grad();
    ///
    /// // Select row 1 with gradient tracking enabled
    /// let mut result = tensor.select(0, 1);
    /// result.backward(None);
    ///
    /// // Verify gradients are computed correctly
    /// let grad = tensor.grad_owned().expect("gradient missing");
    /// assert_eq!(grad.shape().dims(), vec![2, 2]);
    /// // Only row 1 receives gradients
    /// assert_eq!(grad.get(&[0, 0]), 0.0);  // Row 0: no gradient
    /// assert_eq!(grad.get(&[0, 1]), 0.0);  // Row 0: no gradient
    /// assert_eq!(grad.get(&[1, 0]), 1.0);  // Row 1: gradient flows
    /// assert_eq!(grad.get(&[1, 1]), 1.0);  // Row 1: gradient flows
    /// ```
    ///
    /// # Performance Characteristics
    ///
    /// - **Time Complexity**: O(rank) — only shape/stride bookkeeping; no element data is
    ///   touched or moved
    /// - **Memory Usage**: Zero-copy; the view shares the input tensor's memory
    /// - **GradTrack Overhead**: A single registration call when gradient tracking is enabled
    /// - **Memory Layout**: The view keeps the input's strides (minus the selected
    ///   dimension) and may therefore be non-contiguous
    ///
    /// # Implementation Details
    ///
    /// The select operation works by:
    /// 1. Validating the dimension and index bounds
    /// 2. Computing the new shape by removing the selected dimension
    /// 3. Computing the new strides by removing the selected dimension's stride
    /// 4. Calculating the storage offset as `index * stride(dim)`
    /// 5. Creating a strided view over the input with the adjusted shape, strides, and offset
    /// 6. Registering the operation for gradient computation if needed
    ///
    /// # Safety
    ///
    /// This function performs bounds checking to ensure:
    /// - The tensor has non-zero rank
    /// - The specified dimension is within the tensor's rank
    /// - The index is within bounds for the specified dimension
    /// - Memory access is safe through proper offset calculations
    ///
    /// # Panics
    ///
    /// This function will panic if:
    /// - The tensor has zero rank
    /// - `dim` is greater than or equal to the tensor's rank
    /// - `index` is greater than or equal to the size of the specified dimension
    /// - The underlying strided view cannot be constructed (`as_strided_view` returns
    ///   an error)
    ///
    /// # Thread Safety
    ///
    /// This function does not modify the input tensor and can be called concurrently
    /// on different tensors.
    ///
    /// # GradTrack Behavior
    ///
    /// When gradient tracking is enabled a `GradFn::Select` node recording `dim`, `index`,
    /// and the input shape is registered, so that during backward the gradient for the
    /// selected slice can be scattered back into a zero tensor of the input's shape
    /// (all non-selected positions receive zero gradient).
    #[track_caller]
    pub fn select(&self, dim: usize, index: usize) -> Tensor {
        let rank = self.shape().rank();
        assert!(rank > 0, "select requires non-zero rank");
        assert!(
            dim < rank,
            "select dim {} out of bounds for rank {}",
            dim,
            rank
        );
        let dim_size = self.shape().dims()[dim];
        assert!(
            index < dim_size,
            "select index {} out of bounds for dimension {} (size {})",
            index,
            dim,
            dim_size
        );

        // Build new dims/strides removing the selected dimension
        let mut new_dims = Vec::with_capacity(rank - 1);
        let mut new_strides = Vec::with_capacity(rank - 1);
        for i in 0..rank {
            if i == dim {
                continue;
            }
            new_dims.push(self.shape().dims()[i]);
            new_strides.push(self.strides()[i]);
        }

        // Base pointer shift by index * stride(dim)
        let base_offset = index * self.stride(dim);

        // Create zero-copy view using core as_strided with storage_offset
        let mut result = match crate::tensor::core::view::as_strided_view(
            self,
            &new_dims,
            &new_strides,
            base_offset,
        ) {
            Ok(v) => v,
            Err(e) => panic!("select view error: {:?}", e),
        };

        // GradTrack registration: backward scatters grad_output into zeros at the selected slice
        if self.requires_grad() {
            result.set_requires_grad(true);
            let grad_fn = GradFn::Select {
                dim,
                index,
                input_shape: self.shape().dims().to_vec(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
        }

        result
    }
}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting along dimension 0 drops that dimension and yields the row's elements.
    #[test]
    fn test_select_basic() {
        let data = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
        let row = tensor.select(0, 1);
        assert_eq!(row.shape().dims(), vec![3]);
        assert_eq!(row.get(&[0]), 3.0);
        assert_eq!(row.get(&[2]), 5.0);
    }

    /// Backward through select sends ones to the chosen row and zeros elsewhere.
    #[test]
    fn test_select_grad() {
        let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut selected = tensor.select(0, 1);
        selected.backward(None);
        let grad = tensor.grad_owned().expect("grad missing");
        for col in 0..2usize {
            assert_eq!(grad.get(&[0, col]), 0.0); // unselected row: zero gradient
            assert_eq!(grad.get(&[1, col]), 1.0); // selected row: ones flow back
        }
    }
}
223}