//! train_station/tensor/indexing/select.rs — select a slice along a dimension.
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Select a slice along a given dimension at a specific index
6 ///
7 /// This operation extracts a slice from the input tensor by fixing a specific dimension
8 /// at a given index. The result is a tensor with one fewer dimension than the input,
9 /// containing the selected slice.
10 ///
11 /// The select operation returns a view (zero-copy) when the base offset is zero,
12 /// otherwise it creates a contiguous copy to ensure correctness. This operation is
13 /// commonly used for extracting specific rows, columns, or slices from tensors.
14 ///
15 /// # Arguments
16 ///
17 /// * `dim` - The dimension along which to select (must be < tensor rank)
18 /// * `index` - The index within the specified dimension to select (must be < dim size)
19 ///
20 /// # Returns
21 ///
22 /// A tensor with the selected slice. The result has the same shape as the input
23 /// except with the specified dimension removed.
24 ///
25 /// # Examples
26 ///
27 /// ## Basic Row Selection
28 ///
29 /// ```
30 /// use train_station::Tensor;
31 ///
32 /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
33 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
34 ///
35 /// // Select row 1 (dimension 0, index 1)
36 /// let result = tensor.select(0, 1);
37 ///
38 /// // Result shape is [3] (dimension 0 removed)
39 /// assert_eq!(result.shape().dims, vec![3]);
40 /// assert_eq!(result.get(&[0]), 3.0); // First element of row 1
41 /// assert_eq!(result.get(&[1]), 4.0); // Second element of row 1
42 /// assert_eq!(result.get(&[2]), 5.0); // Third element of row 1
43 /// ```
44 ///
45 /// ## Column Selection
46 ///
47 /// ```
48 /// use train_station::Tensor;
49 ///
50 /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
51 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
52 ///
53 /// // Select column 1 (dimension 1, index 1)
54 /// let result = tensor.select(1, 1);
55 ///
56 /// // Result shape is [2] (dimension 1 removed)
57 /// assert_eq!(result.shape().dims, vec![2]);
58 /// assert_eq!(result.get(&[0]), 1.0); // Column 1, row 0
59 /// assert_eq!(result.get(&[1]), 4.0); // Column 1, row 1
60 /// ```
61 ///
62 /// ## Select with Gradient Tracking
63 ///
64 /// ```
65 /// use train_station::Tensor;
66 ///
67 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap()
68 /// .with_requires_grad();
69 ///
70 /// // Select row 1 with gradient tracking enabled
71 /// let mut result = tensor.select(0, 1);
72 /// result.backward(None);
73 ///
74 /// // Verify gradients are computed correctly
75 /// let grad = tensor.grad_by_value().expect("gradient missing");
76 /// assert_eq!(grad.shape().dims, vec![2, 2]);
77 /// // Only row 1 receives gradients
78 /// assert_eq!(grad.get(&[0, 0]), 0.0); // Row 0: no gradient
79 /// assert_eq!(grad.get(&[0, 1]), 0.0); // Row 0: no gradient
80 /// assert_eq!(grad.get(&[1, 0]), 1.0); // Row 1: gradient flows
81 /// assert_eq!(grad.get(&[1, 1]), 1.0); // Row 1: gradient flows
82 /// ```
83 ///
84 /// # Performance Characteristics
85 ///
86 /// - **Time Complexity**: O(n) where n is the number of elements in the selected slice
87 /// - **Memory Usage**: Zero-copy view when base offset is zero, otherwise creates a copy
88 /// - **Optimization**: Uses efficient stride-based access for non-contiguous tensors
89 /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
90 /// - **Memory Layout**: Result is contiguous when a copy is made, view otherwise
91 ///
92 /// # Implementation Details
93 ///
94 /// The select operation works by:
95 /// 1. Validating the dimension and index bounds
96 /// 2. Computing the new shape by removing the selected dimension
97 /// 3. Computing the new strides by removing the selected dimension's stride
98 /// 4. Calculating the base offset for the selected slice
99 /// 5. If base offset is zero: creating a view with adjusted shape/strides
100 /// 6. If base offset is non-zero: creating a contiguous copy of the slice
101 /// 7. Registering the operation for gradient computation if needed
102 ///
103 /// # Safety
104 ///
105 /// This function performs comprehensive bounds checking to ensure:
106 /// - The tensor has non-zero rank
107 /// - The specified dimension is within the tensor's rank
108 /// - The index is within bounds for the specified dimension
109 /// - Memory access is safe through proper offset calculations
110 ///
111 /// # Panics
112 ///
113 /// This function will panic if:
114 /// - The tensor has zero rank
115 /// - `dim` is greater than or equal to the tensor's rank
116 /// - `index` is greater than or equal to the size of the specified dimension
117 ///
118 /// # Thread Safety
119 ///
120 /// This function is thread-safe and can be called concurrently on different tensors.
121 /// The operation does not modify the input tensor and creates either a view or a new tensor.
122 ///
123 /// # View vs Copy Behavior
124 ///
125 /// - **View (zero-copy)**: When the base offset is zero, returns a view that shares
126 /// the same memory as the input tensor with adjusted shape and strides
127 /// - **Copy**: When the base offset is non-zero, creates a contiguous copy to ensure
128 /// correctness across all operations
129 ///
130 /// # GradTrack Behavior
131 ///
132 /// When gradient tracking is enabled:
133 /// - Gradients are scattered back to the selected slice in the input tensor
134 /// - Other positions in the input tensor receive zero gradients
135 /// - This behavior ensures correct gradient flow for the selected elements
136 #[track_caller]
137 pub fn select(&self, dim: usize, index: usize) -> Tensor {
138 let rank = self.shape().rank();
139 assert!(rank > 0, "select requires non-zero rank");
140 assert!(
141 dim < rank,
142 "select dim {} out of bounds for rank {}",
143 dim,
144 rank
145 );
146 let dim_size = self.shape().dims[dim];
147 assert!(
148 index < dim_size,
149 "select index {} out of bounds for dimension {} (size {})",
150 index,
151 dim,
152 dim_size
153 );
154
155 // Build new dims/strides removing the selected dimension
156 let mut new_dims = Vec::with_capacity(rank - 1);
157 let mut new_strides = Vec::with_capacity(rank - 1);
158 for i in 0..rank {
159 if i == dim {
160 continue;
161 }
162 new_dims.push(self.shape().dims[i]);
163 new_strides.push(self.strides()[i]);
164 }
165
166 // Base pointer shift by index * stride(dim)
167 let base_offset = index * self.stride(dim);
168
169 // Create a view with the same data pointer offset by base_offset
170 // We simulate pointer offset by using memory_offset in access; to preserve zero-copy
171 // semantics, we create a view over the same allocation and adjust shape/strides.
172 let view_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
173 let mut result = self.create_view_with_shape(view_shape);
174
175 // To account for base offset, rebase the `result` by materializing a small view window
176 // via contiguous() if base_offset != 0 for non-view correctness. For simplicity and
177 // correctness across all ops, create a contiguous copy if the base offset is non-zero.
178 if base_offset != 0 {
179 // Materialize contiguous slice
180 let mut contiguous = Tensor::new(result.shape().dims.clone());
181 // Copy elements from self using stride-aware reads starting at base_offset
182 let numel = contiguous.size();
183 let rank2 = result.shape().rank();
184 let mut coords = vec![0usize; rank2];
185 for lin in 0..numel {
186 // Decode coords in result space
187 let mut tmp = lin;
188 for i in (0..rank2).rev() {
189 let s = result.shape().dims[i];
190 coords[i] = if s == 0 { 0 } else { tmp % s };
191 if s != 0 {
192 tmp /= s;
193 }
194 }
195 // Map to source coords inserting fixed index at dim
196 let mut src_coords = Vec::with_capacity(rank);
197 for i in 0..rank {
198 if i == dim {
199 src_coords.push(index);
200 } else {
201 let j = if i < dim { i } else { i - 1 };
202 src_coords.push(coords[j]);
203 }
204 }
205 let src_off = self.shape().offset(&src_coords);
206 unsafe {
207 *contiguous.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
208 }
209 }
210 result = contiguous;
211 }
212
213 // GradTrack registration: backward scatters grad_output into zeros at the selected slice
214 if self.requires_grad() {
215 result.set_requires_grad(true);
216 let grad_fn = GradFn::Select {
217 dim,
218 index,
219 input_shape: self.shape().dims.clone(),
220 };
221 result.set_grad_fn(grad_fn.clone());
222 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
223 }
224
225 result
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Row selection with a non-zero base offset exercises the copy path.
    #[test]
    fn test_select_basic() {
        let x = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let s = x.select(0, 1);
        assert_eq!(s.shape().dims, vec![3]);
        assert_eq!(s.get(&[0]), 3.0);
        assert_eq!(s.get(&[2]), 5.0);
    }

    /// Index 0 gives a zero base offset, which takes the zero-copy view path.
    /// This path was previously untested: both existing tests used index 1.
    #[test]
    fn test_select_index_zero_view_path() {
        let x = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let row0 = x.select(0, 0);
        assert_eq!(row0.shape().dims, vec![3]);
        assert_eq!(row0.get(&[0]), 0.0);
        assert_eq!(row0.get(&[2]), 2.0);
        let col0 = x.select(1, 0);
        assert_eq!(col0.shape().dims, vec![2]);
        assert_eq!(col0.get(&[0]), 0.0);
        assert_eq!(col0.get(&[1]), 3.0); // element [1, 0]
    }

    /// Column selection with non-unit result strides exercises the
    /// stride-aware contiguous copy.
    #[test]
    fn test_select_column_copy_path() {
        let x = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let col1 = x.select(1, 1);
        assert_eq!(col1.shape().dims, vec![2]);
        assert_eq!(col1.get(&[0]), 1.0); // element [0, 1]
        assert_eq!(col1.get(&[1]), 4.0); // element [1, 1]
    }

    /// Documented panic: index >= size of the selected dimension.
    #[test]
    #[should_panic]
    fn test_select_index_out_of_bounds_panics() {
        let x = Tensor::from_slice(&[0.0, 1.0], vec![2]).unwrap();
        let _ = x.select(0, 2);
    }

    /// Documented panic: dim >= tensor rank.
    #[test]
    #[should_panic]
    fn test_select_dim_out_of_bounds_panics() {
        let x = Tensor::from_slice(&[0.0, 1.0], vec![2]).unwrap();
        let _ = x.select(1, 0);
    }

    /// Backward scatters ones into the selected row and zeros elsewhere.
    #[test]
    fn test_select_grad() {
        let x = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut s = x.select(0, 1);
        s.backward(None);
        let gx = x.grad_by_value().expect("grad missing");
        // Only row 1 receives ones
        assert_eq!(gx.get(&[0, 0]), 0.0);
        assert_eq!(gx.get(&[0, 1]), 0.0);
        assert_eq!(gx.get(&[1, 0]), 1.0);
        assert_eq!(gx.get(&[1, 1]), 1.0);
    }
}
256}