train_station/tensor/transform/unsqueeze.rs
1//! Tensor unsqueeze operations
2//!
3//! This module provides tensor unsqueeze functionality that adds dimensions
4//! of size 1 to tensors, effectively increasing the dimensionality while
5//! preserving the total number of elements. Unsqueezing is a fundamental
6//! tensor transformation operation used in machine learning for preparing
7//! data for specific layer types, implementing broadcasting operations,
8//! and creating batch dimensions from single samples.
9//!
10//! # Operations
11//!
12//! * `unsqueeze()` - Add a dimension of size 1 at the specified position
13//!
14//! # Performance Characteristics
15//!
16//! * **Zero-Copy Operation**: Returns a view through reshape operation
17//! * **Memory Efficient**: Reuses existing tensor data through view operations
18//! * **Shape Expansion**: Increases tensor rank by adding singleton dimensions
19//! * **Gradient Tracking**: Full GradTrack support through reshape operations
20//! * **Edge Case Handling**: Properly handles tensors of any rank and shape
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Add dimension at the beginning
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
29//! let unsqueezed = tensor.unsqueeze(0);
30//! assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
31//! ```
32//!
33//! ```
34//! use train_station::Tensor;
35//!
36//! // Add dimension at the end
37//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
38//! let unsqueezed = tensor.unsqueeze(1);
39//! assert_eq!(unsqueezed.shape().dims, vec![3, 1]);
40//! ```
41//!
42//! # Gradient Tracking
43//!
44//! The unsqueeze operation supports automatic gradient tracking through
45//! the GradTrack system via the underlying reshape operation. When
46//! `requires_grad` is enabled, the operation preserves gradient
47//! requirements and tracking through the transformation.
48
49use crate::tensor::Tensor;
50
51impl Tensor {
52 /// Add a dimension of size 1 at the specified position
53 ///
54 /// Inserts a new dimension of size 1 at the specified position in the
55 /// tensor's shape, increasing the rank by 1 while preserving the total
56 /// number of elements. This operation is useful for preparing tensors
57 /// for broadcasting, creating batch dimensions, and adapting tensor
58 /// shapes for specific neural network operations.
59 ///
60 /// The unsqueeze operation is the inverse of `squeeze()` - unsqueezing
61 /// a dimension and then squeezing it at the same position returns the
62 /// original tensor.
63 ///
64 /// # Arguments
65 ///
66 /// * `dim` - Position to insert the new dimension (0 <= dim <= rank)
67 ///
68 /// # Returns
69 ///
70 /// A new tensor with an additional dimension of size 1 at the specified
71 /// position. The total number of elements remains unchanged.
72 ///
73 /// # Panics
74 ///
75 /// * If `dim` is out of bounds (dim > rank of the tensor)
76 ///
77 /// # Examples
78 ///
79 /// ```
80 /// use train_station::Tensor;
81 ///
82 /// // Add dimension at the beginning
83 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
84 /// let unsqueezed = tensor.unsqueeze(0);
85 /// assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
86 /// assert_eq!(unsqueezed.get(&[0, 0]), 1.0);
87 /// assert_eq!(unsqueezed.get(&[0, 1]), 2.0);
88 /// assert_eq!(unsqueezed.get(&[0, 2]), 3.0);
89 /// ```
90 ///
91 /// ```
92 /// use train_station::Tensor;
93 ///
94 /// // Add dimension at the end
95 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
96 /// let unsqueezed = tensor.unsqueeze(1);
97 /// assert_eq!(unsqueezed.shape().dims, vec![3, 1]);
98 /// assert_eq!(unsqueezed.get(&[0, 0]), 1.0);
99 /// assert_eq!(unsqueezed.get(&[1, 0]), 2.0);
100 /// assert_eq!(unsqueezed.get(&[2, 0]), 3.0);
101 /// ```
102 ///
103 /// ```
104 /// use train_station::Tensor;
105 ///
106 /// // Add dimension in the middle of 2D tensor
107 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
108 /// let unsqueezed = tensor.unsqueeze(1);
109 /// assert_eq!(unsqueezed.shape().dims, vec![2, 1, 2]);
110 /// assert_eq!(unsqueezed.get(&[0, 0, 0]), 1.0);
111 /// assert_eq!(unsqueezed.get(&[0, 0, 1]), 2.0);
112 /// assert_eq!(unsqueezed.get(&[1, 0, 0]), 3.0);
113 /// assert_eq!(unsqueezed.get(&[1, 0, 1]), 4.0);
114 /// ```
115 ///
116 /// ```
117 /// use train_station::Tensor;
118 ///
119 /// // Unsqueeze preserves data integrity
120 /// let data = vec![1.0, 2.0, 3.0, 4.0];
121 /// let tensor = Tensor::from_slice(&data, vec![4]).unwrap();
122 /// let unsqueezed = tensor.unsqueeze(0);
123 /// assert_eq!(unsqueezed.shape().dims, vec![1, 4]);
124 /// assert_eq!(unsqueezed.size(), 4);
125 /// for (i, &d) in data.iter().enumerate() {
126 /// assert_eq!(unsqueezed.get(&[0, i]), d);
127 /// }
128 /// ```
129 ///
130 /// ```
131 /// use train_station::Tensor;
132 ///
133 /// // Unsqueeze with gradient tracking
134 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
135 /// tensor.set_requires_grad(true);
136 ///
137 /// let unsqueezed = tensor.unsqueeze(0);
138 /// assert!(unsqueezed.requires_grad());
139 /// assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
140 /// ```
141 ///
142 /// ```
143 /// use train_station::Tensor;
144 ///
145 /// // Unsqueeze and squeeze roundtrip
146 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
147 /// let unsqueezed = tensor.unsqueeze(0);
148 /// assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
149 ///
150 /// let squeezed = unsqueezed.squeeze(Some(0));
151 /// assert_eq!(squeezed.shape().dims, vec![3]);
152 /// assert_eq!(squeezed.get(&[0]), 1.0);
153 /// assert_eq!(squeezed.get(&[2]), 3.0);
154 /// ```
155 ///
156 /// ```
157 /// use train_station::Tensor;
158 ///
159 /// // Multiple unsqueeze operations
160 /// let tensor = Tensor::from_slice(&[42.0], vec![1]).unwrap();
161 /// let unsqueezed1 = tensor.unsqueeze(0);
162 /// assert_eq!(unsqueezed1.shape().dims, vec![1, 1]);
163 ///
164 /// let unsqueezed2 = unsqueezed1.unsqueeze(0);
165 /// assert_eq!(unsqueezed2.shape().dims, vec![1, 1, 1]);
166 /// assert_eq!(unsqueezed2.get(&[0, 0, 0]), 42.0);
167 /// ```
168 ///
169 /// # Performance
170 ///
171 /// - **Time Complexity**: O(1) - Returns a view through reshape operation
172 /// - **Memory Usage**: No additional memory allocation (view operation)
173 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
174 /// - **Shape Transformation**: Increases tensor rank by adding singleton dimensions
175 ///
176 /// # Relationship to Other Operations
177 ///
178 /// This operation is related to other tensor transformations:
179 /// - `squeeze()` - Inverse operation that removes size-1 dimensions
180 /// - `reshape()` - More general shape transformation operation
181 /// - `expand()` - Broadcasts dimensions to larger sizes
182 ///
183 /// # Memory Layout
184 ///
185 /// The unsqueezed tensor maintains the same underlying data as the original
186 /// tensor through the reshape operation. This ensures zero-copy behavior
187 /// when the tensor is contiguous, with only the shape metadata being
188 /// modified to reflect the increased dimensionality.
189 ///
190 /// # Broadcasting Applications
191 ///
192 /// Unsqueeze is commonly used for broadcasting operations:
193 /// ```rust
194 /// use train_station::Tensor;
195 ///
196 /// // Prepare for broadcasting: [3] -> [1, 3] for row-wise operations
197 /// let vector = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
198 /// let row_vector = vector.unsqueeze(0); // Shape: [1, 3]
199 ///
200 /// // Prepare for broadcasting: [3] -> [3, 1] for column-wise operations
201 /// let column_vector = vector.unsqueeze(1); // Shape: [3, 1]
202 /// ```
203 ///
204 /// # Neural Network Applications
205 ///
206 /// Unsqueeze is essential for neural network operations:
207 /// ```rust
208 /// use train_station::Tensor;
209 ///
210 /// // Single sample -> batch dimension for neural network input
211 /// let sample = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
212 /// let batch = sample.unsqueeze(0); // Shape: [1, 3] for batch processing
213 ///
214 /// // Add channel dimension for convolutional operations
215 /// let feature_map = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
216 /// let with_channels = feature_map.unsqueeze(0); // Shape: [1, 2, 2] for conv layers
217 /// ```
218 #[track_caller]
219 pub fn unsqueeze(&self, dim: usize) -> Tensor {
220 let mut new_dims = self.shape().dims.clone();
221 assert!(dim <= new_dims.len(), "Dimension {} out of bounds", dim);
222 new_dims.insert(dim, 1);
223
224 // Convert to i32 for reshape call
225 let new_shape: Vec<i32> = new_dims.iter().map(|&d| d as i32).collect();
226 self.reshape(new_shape)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Unsqueezing a 1-D tensor at either end yields the expected shapes.
    #[test]
    fn test_unsqueeze() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        assert_eq!(t.unsqueeze(0).shape().dims, vec![1, 3]);
        assert_eq!(t.unsqueeze(1).shape().dims, vec![3, 1]);
    }

    /// A 2-D tensor accepts insertion at every valid position.
    #[test]
    fn test_unsqueeze_2d() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let cases: [Vec<usize>; 3] = [vec![1, 2, 2], vec![2, 1, 2], vec![2, 2, 1]];

        for (pos, want) in cases.iter().enumerate() {
            assert_eq!(t.unsqueeze(pos).shape().dims, *want);
        }
    }

    /// Element values survive the shape change unchanged.
    #[test]
    fn test_unsqueeze_preserves_data() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let widened = Tensor::from_slice(&values, vec![4]).unwrap().unsqueeze(0);

        assert_eq!(widened.shape().dims, vec![1, 4]);
        assert_eq!(widened.size(), 4);
        for (idx, &expected) in values.iter().enumerate() {
            assert_eq!(widened.get(&[0, idx]), expected);
        }
    }

    /// An insertion index past the rank must panic with the bounds message.
    #[test]
    #[should_panic(expected = "Dimension 4 out of bounds")]
    fn test_unsqueeze_out_of_bounds() {
        Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])
            .unwrap()
            .unsqueeze(4);
    }

    /// Gradient tracking flows through the view operation.
    #[test]
    fn test_unsqueeze_with_gradients() {
        let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        t.set_requires_grad(true);

        let view = t.unsqueeze(0);
        assert!(view.requires_grad());
        assert_eq!(view.shape().dims, vec![1, 3]);
    }

    /// `squeeze(Some(d))` undoes `unsqueeze(d)`.
    #[test]
    fn test_unsqueeze_squeeze_roundtrip() {
        let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let restored = original.unsqueeze(0).squeeze(Some(0));

        assert_eq!(restored.shape().dims, original.shape().dims);
        assert_eq!(restored.get(&[0]), original.get(&[0]));
        assert_eq!(restored.get(&[2]), original.get(&[2]));
    }

    /// Repeated unsqueezes stack singleton dimensions at the front.
    #[test]
    fn test_multiple_unsqueeze() {
        let stacked = Tensor::from_slice(&[42.0], vec![1])
            .unwrap()
            .unsqueeze(0)
            .unsqueeze(0);

        assert_eq!(stacked.shape().dims, vec![1, 1, 1]);
        assert_eq!(stacked.get(&[0, 0, 0]), 42.0);
    }
}