//! Select-along-dimension indexing for `Tensor` (train_station/tensor/indexing/select.rs).
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Select a slice along a given dimension at a specific index
6 ///
7 /// This operation extracts a slice from the input tensor by fixing a specific dimension
8 /// at a given index. The result is a tensor with one fewer dimension than the input,
9 /// containing the selected slice.
10 ///
11 /// The select operation returns a view (zero-copy) when the base offset is zero,
12 /// otherwise it creates a contiguous copy to ensure correctness. This operation is
13 /// commonly used for extracting specific rows, columns, or slices from tensors.
14 ///
15 /// # Arguments
16 ///
17 /// * `dim` - The dimension along which to select (must be < tensor rank)
18 /// * `index` - The index within the specified dimension to select (must be < dim size)
19 ///
20 /// # Returns
21 ///
22 /// A tensor with the selected slice. The result has the same shape as the input
23 /// except with the specified dimension removed.
24 ///
25 /// # Examples
26 ///
27 /// ## Basic Row Selection
28 ///
29 /// ```
30 /// use train_station::Tensor;
31 ///
32 /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
33 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
34 ///
35 /// // Select row 1 (dimension 0, index 1)
36 /// let result = tensor.select(0, 1);
37 ///
38 /// // Result shape is [3] (dimension 0 removed)
39 /// assert_eq!(result.shape().dims, vec![3]);
40 /// assert_eq!(result.get(&[0]), 3.0); // First element of row 1
41 /// assert_eq!(result.get(&[1]), 4.0); // Second element of row 1
42 /// assert_eq!(result.get(&[2]), 5.0); // Third element of row 1
43 /// ```
44 ///
45 /// ## Column Selection
46 ///
47 /// ```
48 /// use train_station::Tensor;
49 ///
50 /// // Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
51 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
52 ///
53 /// // Select column 1 (dimension 1, index 1)
54 /// let result = tensor.select(1, 1);
55 ///
56 /// // Result shape is [2] (dimension 1 removed)
57 /// assert_eq!(result.shape().dims, vec![2]);
58 /// assert_eq!(result.get(&[0]), 1.0); // Column 1, row 0
59 /// assert_eq!(result.get(&[1]), 4.0); // Column 1, row 1
60 /// ```
61 ///
62 /// ## Select with Gradient Tracking
63 ///
64 /// ```
65 /// use train_station::Tensor;
66 ///
67 /// let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap()
68 /// .with_requires_grad();
69 ///
70 /// // Select row 1 with gradient tracking enabled
71 /// let mut result = tensor.select(0, 1);
72 /// result.backward(None);
73 ///
74 /// // Verify gradients are computed correctly
75 /// let grad = tensor.grad_by_value().expect("gradient missing");
76 /// assert_eq!(grad.shape().dims, vec![2, 2]);
77 /// // Only row 1 receives gradients
78 /// assert_eq!(grad.get(&[0, 0]), 0.0); // Row 0: no gradient
79 /// assert_eq!(grad.get(&[0, 1]), 0.0); // Row 0: no gradient
80 /// assert_eq!(grad.get(&[1, 0]), 1.0); // Row 1: gradient flows
81 /// assert_eq!(grad.get(&[1, 1]), 1.0); // Row 1: gradient flows
82 /// ```
83 ///
84 /// # Performance Characteristics
85 ///
86 /// - **Time Complexity**: O(n) where n is the number of elements in the selected slice
87 /// - **Memory Usage**: Zero-copy view when base offset is zero, otherwise creates a copy
88 /// - **Optimization**: Uses efficient stride-based access for non-contiguous tensors
89 /// - **GradTrack Overhead**: Minimal overhead when gradient tracking is enabled
90 /// - **Memory Layout**: Result is contiguous when a copy is made, view otherwise
91 ///
92 /// # Implementation Details
93 ///
94 /// The select operation works by:
95 /// 1. Validating the dimension and index bounds
96 /// 2. Computing the new shape by removing the selected dimension
97 /// 3. Computing the new strides by removing the selected dimension's stride
98 /// 4. Calculating the base offset for the selected slice
99 /// 5. If base offset is zero: creating a view with adjusted shape/strides
100 /// 6. If base offset is non-zero: creating a contiguous copy of the slice
101 /// 7. Registering the operation for gradient computation if needed
102 ///
103 /// # Safety
104 ///
105 /// This function performs comprehensive bounds checking to ensure:
106 /// - The tensor has non-zero rank
107 /// - The specified dimension is within the tensor's rank
108 /// - The index is within bounds for the specified dimension
109 /// - Memory access is safe through proper offset calculations
110 ///
111 /// # Panics
112 ///
113 /// This function will panic if:
114 /// - The tensor has zero rank
115 /// - `dim` is greater than or equal to the tensor's rank
116 /// - `index` is greater than or equal to the size of the specified dimension
117 ///
118 /// # Thread Safety
119 ///
120 /// This function is thread-safe and can be called concurrently on different tensors.
121 /// The operation does not modify the input tensor and creates either a view or a new tensor.
122 ///
123 /// # View vs Copy Behavior
124 ///
125 /// - **View (zero-copy)**: When the base offset is zero, returns a view that shares
126 /// the same memory as the input tensor with adjusted shape and strides
127 /// - **Copy**: When the base offset is non-zero, creates a contiguous copy to ensure
128 /// correctness across all operations
129 ///
130 /// # GradTrack Behavior
131 ///
132 /// When gradient tracking is enabled:
133 /// - Gradients are scattered back to the selected slice in the input tensor
134 /// - Other positions in the input tensor receive zero gradients
135 /// - This behavior ensures correct gradient flow for the selected elements
136 pub fn select(&self, dim: usize, index: usize) -> Tensor {
137 let rank = self.shape().rank();
138 assert!(rank > 0, "select requires non-zero rank");
139 assert!(
140 dim < rank,
141 "select dim {} out of bounds for rank {}",
142 dim,
143 rank
144 );
145 let dim_size = self.shape().dims[dim];
146 assert!(
147 index < dim_size,
148 "select index {} out of bounds for dimension {} (size {})",
149 index,
150 dim,
151 dim_size
152 );
153
154 // Build new dims/strides removing the selected dimension
155 let mut new_dims = Vec::with_capacity(rank - 1);
156 let mut new_strides = Vec::with_capacity(rank - 1);
157 for i in 0..rank {
158 if i == dim {
159 continue;
160 }
161 new_dims.push(self.shape().dims[i]);
162 new_strides.push(self.strides()[i]);
163 }
164
165 // Base pointer shift by index * stride(dim)
166 let base_offset = index * self.stride(dim);
167
168 // Create a view with the same data pointer offset by base_offset
169 // We simulate pointer offset by using memory_offset in access; to preserve zero-copy
170 // semantics, we create a view over the same allocation and adjust shape/strides.
171 let view_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
172 let mut result = self.create_view_with_shape(view_shape);
173
174 // To account for base offset, rebase the `result` by materializing a small view window
175 // via contiguous() if base_offset != 0 for non-view correctness. For simplicity and
176 // correctness across all ops, create a contiguous copy if the base offset is non-zero.
177 if base_offset != 0 {
178 // Materialize contiguous slice
179 let mut contiguous = Tensor::new(result.shape().dims.clone());
180 // Copy elements from self using stride-aware reads starting at base_offset
181 let numel = contiguous.size();
182 let rank2 = result.shape().rank();
183 let mut coords = vec![0usize; rank2];
184 for lin in 0..numel {
185 // Decode coords in result space
186 let mut tmp = lin;
187 for i in (0..rank2).rev() {
188 let s = result.shape().dims[i];
189 coords[i] = if s == 0 { 0 } else { tmp % s };
190 if s != 0 {
191 tmp /= s;
192 }
193 }
194 // Map to source coords inserting fixed index at dim
195 let mut src_coords = Vec::with_capacity(rank);
196 for i in 0..rank {
197 if i == dim {
198 src_coords.push(index);
199 } else {
200 let j = if i < dim { i } else { i - 1 };
201 src_coords.push(coords[j]);
202 }
203 }
204 let src_off = self.shape().offset(&src_coords);
205 unsafe {
206 *contiguous.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
207 }
208 }
209 result = contiguous;
210 }
211
212 // GradTrack registration: backward scatters grad_output into zeros at the selected slice
213 if self.requires_grad() {
214 result.set_requires_grad(true);
215 let grad_fn = GradFn::Select {
216 dim,
217 index,
218 input_shape: self.shape().dims.clone(),
219 };
220 result.set_grad_fn(grad_fn.clone());
221 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
222 }
223
224 result
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting a row drops the leading dimension and yields that row's data.
    #[test]
    fn test_select_basic() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
        let row = input.select(0, 1);
        assert_eq!(row.shape().dims, vec![3]);
        assert_eq!(row.get(&[0]), 3.0);
        assert_eq!(row.get(&[2]), 5.0);
    }

    /// Backward through select scatters ones into the selected row and
    /// zeros everywhere else in the input-shaped gradient.
    #[test]
    fn test_select_grad() {
        let input = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut picked = input.select(0, 1);
        picked.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        // Row 0 was not selected, so it receives no gradient.
        assert_eq!(grad.get(&[0, 0]), 0.0);
        assert_eq!(grad.get(&[0, 1]), 0.0);
        // Row 1 is the selected slice and receives ones.
        assert_eq!(grad.get(&[1, 0]), 1.0);
        assert_eq!(grad.get(&[1, 1]), 1.0);
    }
}
255}