// train_station/tensor/transform/unsqueeze.rs

//! Tensor unsqueeze operations
//!
//! This module provides tensor unsqueeze functionality that adds dimensions
//! of size 1 to tensors, effectively increasing the dimensionality while
//! preserving the total number of elements. Unsqueezing is a fundamental
//! tensor transformation operation used in machine learning for preparing
//! data for specific layer types, implementing broadcasting operations,
//! and creating batch dimensions from single samples.
//!
//! # Operations
//!
//! * `unsqueeze()` - Add a dimension of size 1 at the specified position
//!
//! # Performance Characteristics
//!
//! * **Zero-Copy Operation**: Returns a view through the reshape operation
//! * **Memory Efficient**: Reuses existing tensor data through view operations
//! * **Shape Expansion**: Increases tensor rank by adding singleton dimensions
//! * **Gradient Tracking**: Full GradTrack support through reshape operations
//! * **Edge Case Handling**: Properly handles tensors of any rank and shape
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Add dimension at the beginning
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//! let unsqueezed = tensor.unsqueeze(0);
//! assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
//! ```
//!
//! ```
//! use train_station::Tensor;
//!
//! // Add dimension at the end
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//! let unsqueezed = tensor.unsqueeze(1);
//! assert_eq!(unsqueezed.shape().dims, vec![3, 1]);
//! ```
//!
//! # Gradient Tracking
//!
//! The unsqueeze operation supports automatic gradient tracking through
//! the GradTrack system via the underlying reshape operation. When
//! `requires_grad` is enabled, the operation preserves gradient
//! requirements and tracking through the transformation.

49use crate::tensor::Tensor;
50
51impl Tensor {
52    /// Add a dimension of size 1 at the specified position
53    ///
54    /// Inserts a new dimension of size 1 at the specified position in the
55    /// tensor's shape, increasing the rank by 1 while preserving the total
56    /// number of elements. This operation is useful for preparing tensors
57    /// for broadcasting, creating batch dimensions, and adapting tensor
58    /// shapes for specific neural network operations.
59    ///
60    /// The unsqueeze operation is the inverse of `squeeze()` - unsqueezing
61    /// a dimension and then squeezing it at the same position returns the
62    /// original tensor.
63    ///
64    /// # Arguments
65    ///
66    /// * `dim` - Position to insert the new dimension (0 <= dim <= rank)
67    ///
68    /// # Returns
69    ///
70    /// A new tensor with an additional dimension of size 1 at the specified
71    /// position. The total number of elements remains unchanged.
72    ///
73    /// # Panics
74    ///
75    /// * If `dim` is out of bounds (dim > rank of the tensor)
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use train_station::Tensor;
81    ///
82    /// // Add dimension at the beginning
83    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
84    /// let unsqueezed = tensor.unsqueeze(0);
85    /// assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
86    /// assert_eq!(unsqueezed.get(&[0, 0]), 1.0);
87    /// assert_eq!(unsqueezed.get(&[0, 1]), 2.0);
88    /// assert_eq!(unsqueezed.get(&[0, 2]), 3.0);
89    /// ```
90    ///
91    /// ```
92    /// use train_station::Tensor;
93    ///
94    /// // Add dimension at the end
95    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
96    /// let unsqueezed = tensor.unsqueeze(1);
97    /// assert_eq!(unsqueezed.shape().dims, vec![3, 1]);
98    /// assert_eq!(unsqueezed.get(&[0, 0]), 1.0);
99    /// assert_eq!(unsqueezed.get(&[1, 0]), 2.0);
100    /// assert_eq!(unsqueezed.get(&[2, 0]), 3.0);
101    /// ```
102    ///
103    /// ```
104    /// use train_station::Tensor;
105    ///
106    /// // Add dimension in the middle of 2D tensor
107    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
108    /// let unsqueezed = tensor.unsqueeze(1);
109    /// assert_eq!(unsqueezed.shape().dims, vec![2, 1, 2]);
110    /// assert_eq!(unsqueezed.get(&[0, 0, 0]), 1.0);
111    /// assert_eq!(unsqueezed.get(&[0, 0, 1]), 2.0);
112    /// assert_eq!(unsqueezed.get(&[1, 0, 0]), 3.0);
113    /// assert_eq!(unsqueezed.get(&[1, 0, 1]), 4.0);
114    /// ```
115    ///
116    /// ```
117    /// use train_station::Tensor;
118    ///
119    /// // Unsqueeze preserves data integrity
120    /// let data = vec![1.0, 2.0, 3.0, 4.0];
121    /// let tensor = Tensor::from_slice(&data, vec![4]).unwrap();
122    /// let unsqueezed = tensor.unsqueeze(0);
123    /// assert_eq!(unsqueezed.shape().dims, vec![1, 4]);
124    /// assert_eq!(unsqueezed.size(), 4);
125    /// for (i, &d) in data.iter().enumerate() {
126    ///     assert_eq!(unsqueezed.get(&[0, i]), d);
127    /// }
128    /// ```
129    ///
130    /// ```
131    /// use train_station::Tensor;
132    ///
133    /// // Unsqueeze with gradient tracking
134    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
135    /// tensor.set_requires_grad(true);
136    ///
137    /// let unsqueezed = tensor.unsqueeze(0);
138    /// assert!(unsqueezed.requires_grad());
139    /// assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
140    /// ```
141    ///
142    /// ```
143    /// use train_station::Tensor;
144    ///
145    /// // Unsqueeze and squeeze roundtrip
146    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
147    /// let unsqueezed = tensor.unsqueeze(0);
148    /// assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
149    ///
150    /// let squeezed = unsqueezed.squeeze(Some(0));
151    /// assert_eq!(squeezed.shape().dims, vec![3]);
152    /// assert_eq!(squeezed.get(&[0]), 1.0);
153    /// assert_eq!(squeezed.get(&[2]), 3.0);
154    /// ```
155    ///
156    /// ```
157    /// use train_station::Tensor;
158    ///
159    /// // Multiple unsqueeze operations
160    /// let tensor = Tensor::from_slice(&[42.0], vec![1]).unwrap();
161    /// let unsqueezed1 = tensor.unsqueeze(0);
162    /// assert_eq!(unsqueezed1.shape().dims, vec![1, 1]);
163    ///
164    /// let unsqueezed2 = unsqueezed1.unsqueeze(0);
165    /// assert_eq!(unsqueezed2.shape().dims, vec![1, 1, 1]);
166    /// assert_eq!(unsqueezed2.get(&[0, 0, 0]), 42.0);
167    /// ```
168    ///
169    /// # Performance
170    ///
171    /// - **Time Complexity**: O(1) - Returns a view through reshape operation
172    /// - **Memory Usage**: No additional memory allocation (view operation)
173    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
174    /// - **Shape Transformation**: Increases tensor rank by adding singleton dimensions
175    ///
176    /// # Relationship to Other Operations
177    ///
178    /// This operation is related to other tensor transformations:
179    /// - `squeeze()` - Inverse operation that removes size-1 dimensions
180    /// - `reshape()` - More general shape transformation operation
181    /// - `expand()` - Broadcasts dimensions to larger sizes
182    ///
183    /// # Memory Layout
184    ///
185    /// The unsqueezed tensor maintains the same underlying data as the original
186    /// tensor through the reshape operation. This ensures zero-copy behavior
187    /// when the tensor is contiguous, with only the shape metadata being
188    /// modified to reflect the increased dimensionality.
189    ///
190    /// # Broadcasting Applications
191    ///
192    /// Unsqueeze is commonly used for broadcasting operations:
193    /// ```rust
194    /// use train_station::Tensor;
195    ///
196    /// // Prepare for broadcasting: [3] -> [1, 3] for row-wise operations
197    /// let vector = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
198    /// let row_vector = vector.unsqueeze(0); // Shape: [1, 3]
199    ///
200    /// // Prepare for broadcasting: [3] -> [3, 1] for column-wise operations
201    /// let column_vector = vector.unsqueeze(1); // Shape: [3, 1]
202    /// ```
203    ///
204    /// # Neural Network Applications
205    ///
206    /// Unsqueeze is essential for neural network operations:
207    /// ```rust
208    /// use train_station::Tensor;
209    ///
210    /// // Single sample -> batch dimension for neural network input
211    /// let sample = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
212    /// let batch = sample.unsqueeze(0); // Shape: [1, 3] for batch processing
213    ///
214    /// // Add channel dimension for convolutional operations
215    /// let feature_map = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
216    /// let with_channels = feature_map.unsqueeze(0); // Shape: [1, 2, 2] for conv layers
217    /// ```
218    pub fn unsqueeze(&self, dim: usize) -> Tensor {
219        let mut new_dims = self.shape().dims.clone();
220        assert!(dim <= new_dims.len(), "Dimension {} out of bounds", dim);
221        new_dims.insert(dim, 1);
222
223        // Convert to i32 for reshape call
224        let new_shape: Vec<i32> = new_dims.iter().map(|&d| d as i32).collect();
225        self.reshape(new_shape)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Unsqueezing a 1-D tensor at either end produces the expected shapes.
    #[test]
    fn test_unsqueeze() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        // New leading axis: [3] -> [1, 3]
        let leading = base.unsqueeze(0);
        assert_eq!(leading.shape().dims, vec![1, 3]);

        // New trailing axis: [3] -> [3, 1]
        let trailing = base.unsqueeze(1);
        assert_eq!(trailing.shape().dims, vec![3, 1]);
    }

    /// A 2-D tensor accepts the new axis at any of the three valid slots.
    #[test]
    fn test_unsqueeze_2d() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        let cases = vec![
            (0, vec![1, 2, 2]),
            (1, vec![2, 1, 2]),
            (2, vec![2, 2, 1]),
        ];
        for (axis, expected) in cases {
            assert_eq!(base.unsqueeze(axis).shape().dims, expected);
        }
    }

    /// The view returned by unsqueeze exposes exactly the original elements.
    #[test]
    fn test_unsqueeze_preserves_data() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let base = Tensor::from_slice(&values, vec![4]).unwrap();
        let viewed = base.unsqueeze(0);

        assert_eq!(viewed.shape().dims, vec![1, 4]);
        assert_eq!(viewed.size(), 4);

        for (col, &expected) in values.iter().enumerate() {
            assert_eq!(viewed.get(&[0, col]), expected);
        }
    }

    /// Requesting an insertion index past the rank must panic.
    #[test]
    #[should_panic(expected = "Dimension 4 out of bounds")]
    fn test_unsqueeze_out_of_bounds() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        base.unsqueeze(4); // rank is 1, so any dim > 1 is invalid
    }

    /// Gradient requirements carry through the unsqueeze view.
    #[test]
    fn test_unsqueeze_with_gradients() {
        let mut base = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        base.set_requires_grad(true);

        let viewed = base.unsqueeze(0);
        assert!(viewed.requires_grad());
        assert_eq!(viewed.shape().dims, vec![1, 3]);
    }

    /// Squeezing the axis that unsqueeze inserted restores the original.
    #[test]
    fn test_unsqueeze_squeeze_roundtrip() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let restored = base.unsqueeze(0).squeeze(Some(0));

        assert_eq!(restored.shape().dims, base.shape().dims);
        assert_eq!(restored.get(&[0]), base.get(&[0]));
        assert_eq!(restored.get(&[2]), base.get(&[2]));
    }

    /// Repeated unsqueezes keep stacking singleton axes.
    #[test]
    fn test_multiple_unsqueeze() {
        let base = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let stacked = base.unsqueeze(0).unsqueeze(0);

        assert_eq!(stacked.shape().dims, vec![1, 1, 1]);
        assert_eq!(stacked.get(&[0, 0, 0]), 42.0);
    }
}