//! train_station/tensor/indexing/select.rs
//!
//! `Tensor::select`: zero-copy selection of a slice along one dimension.
1use crate::gradtrack::{is_grad_enabled, GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
impl Tensor {
    /// Select a slice along a given dimension at a specific index.
    ///
    /// Extracts the sub-tensor obtained by fixing `dim` at `index`. The result
    /// has one fewer dimension than the input: its shape is the input's shape
    /// with dimension `dim` removed.
    ///
    /// The result is always a zero-copy view: the output shares the input's
    /// storage, with `dim` removed from both the shape and the strides and the
    /// storage offset advanced by `index * stride(dim)`. No element data is
    /// copied by this operation.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension along which to select (must be `< rank`)
    /// * `index` - The index within `dim` to select (must be `< dims()[dim]`)
    ///
    /// # Returns
    ///
    /// A view tensor containing the selected slice. Its shape is the input
    /// shape with dimension `dim` removed; it aliases the input's memory.
    ///
    /// # Examples
    ///
    /// ## Basic Row Selection
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
    ///
    /// // Select row 1 (dimension 0, index 1)
    /// let result = tensor.select(0, 1);
    ///
    /// // Result shape is [3] (dimension 0 removed)
    /// assert_eq!(result.shape().dims(), vec![3]);
    /// assert_eq!(result.get(&[0]), 3.0); // First element of row 1
    /// assert_eq!(result.get(&[1]), 4.0); // Second element of row 1
    /// assert_eq!(result.get(&[2]), 5.0); // Third element of row 1
    /// ```
    ///
    /// ## Column Selection
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
    ///
    /// // Select column 1 (dimension 1, index 1)
    /// let result = tensor.select(1, 1);
    ///
    /// // Result shape is [2] (dimension 1 removed)
    /// assert_eq!(result.shape().dims(), vec![2]);
    /// assert_eq!(result.get(&[0]), 1.0); // Column 1, row 0
    /// assert_eq!(result.get(&[1]), 4.0); // Column 1, row 1
    /// ```
    ///
    /// ## Select with Gradient Tracking
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap()
    ///     .with_requires_grad();
    ///
    /// // Select row 1 with gradient tracking enabled
    /// let mut result = tensor.select(0, 1);
    /// result.backward(None);
    ///
    /// // Verify gradients are computed correctly
    /// let grad = tensor.grad_owned().expect("gradient missing");
    /// assert_eq!(grad.shape().dims(), vec![2, 2]);
    /// // Only row 1 receives gradients
    /// assert_eq!(grad.get(&[0, 0]), 0.0); // Row 0: no gradient
    /// assert_eq!(grad.get(&[0, 1]), 0.0); // Row 0: no gradient
    /// assert_eq!(grad.get(&[1, 0]), 1.0); // Row 1: gradient flows
    /// assert_eq!(grad.get(&[1, 1]), 1.0); // Row 1: gradient flows
    /// ```
    ///
    /// # Performance
    ///
    /// - **Time Complexity**: O(rank) — only the shape/stride vectors are
    ///   rebuilt; no element data is touched.
    /// - **Memory Usage**: zero-copy; the view aliases the input's storage.
    /// - **GradTrack Overhead**: a single graph registration when gradient
    ///   tracking is enabled and the input requires gradients.
    ///
    /// # Implementation Details
    ///
    /// The select operation works by:
    /// 1. Validating the dimension and index bounds
    /// 2. Computing the new shape by removing the selected dimension
    /// 3. Computing the new strides by removing the selected dimension's stride
    /// 4. Calculating the storage offset as `index * stride(dim)`
    /// 5. Creating a strided view over the same storage at that offset
    /// 6. Registering the operation for gradient computation if needed
    ///
    /// # Panics
    ///
    /// This function will panic if:
    /// - The tensor has zero rank
    /// - `dim` is greater than or equal to the tensor's rank
    /// - `index` is greater than or equal to the size of the specified dimension
    /// - The underlying strided-view construction fails (reported as
    ///   `"select view error: …"`)
    ///
    /// # GradTrack Behavior
    ///
    /// When the global gradient switch is on and the input requires gradients,
    /// the result is registered with a `GradFn::Select` node recording `dim`,
    /// `index`, and the input shape. During backward, gradients are scattered
    /// back to the selected slice of the input; all other positions receive
    /// zero gradients (see the gradient example above).
    #[track_caller]
    pub fn select(&self, dim: usize, index: usize) -> Tensor {
        let rank = self.shape().rank();
        assert!(rank > 0, "select requires non-zero rank");
        assert!(
            dim < rank,
            "select dim {} out of bounds for rank {}",
            dim,
            rank
        );
        let dim_size = self.shape().dims()[dim];
        assert!(
            index < dim_size,
            "select index {} out of bounds for dimension {} (size {})",
            index,
            dim,
            dim_size
        );

        // Build new dims/strides removing the selected dimension
        let mut new_dims = Vec::with_capacity(rank - 1);
        let mut new_strides = Vec::with_capacity(rank - 1);
        for i in 0..rank {
            if i == dim {
                continue;
            }
            new_dims.push(self.shape().dims()[i]);
            new_strides.push(self.strides()[i]);
        }

        // Base pointer shift by index * stride(dim)
        let base_offset = index * self.stride(dim);

        // Create zero-copy view using core as_strided with storage_offset
        let mut result = match crate::tensor::core::view::as_strided_view(
            self,
            &new_dims,
            &new_strides,
            base_offset,
        ) {
            Ok(v) => v,
            Err(e) => panic!("select view error: {:?}", e),
        };

        // GradTrack registration only when gradients are enabled and input requires grad
        if self.requires_grad() && is_grad_enabled() {
            result.set_requires_grad(true);
            let grad_fn = GradFn::Select {
                dim,
                index,
                input_shape: self.shape().dims().to_vec(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
        }

        result
    }
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting a row from a 2x3 tensor drops the leading dimension and
    /// yields the elements of that row.
    #[test]
    fn test_select_basic() {
        let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let row = tensor.select(0, 1);
        assert_eq!(row.shape().dims(), vec![3]);
        assert_eq!(row.get(&[0]), 3.0);
        assert_eq!(row.get(&[2]), 5.0);
    }

    /// Backward through select scatters ones into the chosen row and leaves
    /// every other position of the input gradient at zero.
    #[test]
    fn test_select_grad() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut selected = input.select(0, 1);
        selected.backward(None);
        let grad = input.grad_owned().expect("grad missing");
        // (row, col, expected): row 0 untouched, row 1 receives ones.
        let expected = [(0, 0, 0.0), (0, 1, 0.0), (1, 0, 1.0), (1, 1, 1.0)];
        for (r, c, want) in expected {
            assert_eq!(grad.get(&[r, c]), want);
        }
    }
}
229}