train_station/tensor/transform/squeeze.rs
1//! Tensor squeeze operations
2//!
3//! This module provides tensor squeeze functionality that removes dimensions
4//! of size 1 from tensors, effectively reducing the dimensionality while
5//! preserving the total number of elements. Squeezing is a fundamental
6//! tensor transformation operation used in machine learning for cleaning
7//! up tensor shapes, preparing data for specific operations, and
8//! implementing shape normalization.
9//!
10//! # Operations
11//!
12//! * `squeeze()` - Remove dimensions of size 1 from tensor
13//!
14//! # Performance Characteristics
15//!
16//! * **Zero-Copy Operation**: Returns a view when possible, avoiding data copying
17//! * **Memory Efficient**: Reuses existing tensor data through reshape operations
18//! * **Shape Reduction**: Reduces tensor rank by removing singleton dimensions
19//! * **Gradient Tracking**: Full GradTrack support through reshape operations
//! * **Edge Case Handling**: A tensor whose dimensions are all size 1 squeezes to shape `[1]` rather than an empty shape
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Squeeze all size-1 dimensions
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
29//! let squeezed = tensor.squeeze(None);
30//! assert_eq!(squeezed.shape().dims(), vec![3]);
31//! ```
32//!
33//! ```
34//! use train_station::Tensor;
35//!
36//! // Squeeze specific dimension
37//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
38//! let squeezed = tensor.squeeze(Some(0));
39//! assert_eq!(squeezed.shape().dims(), vec![3, 1]);
40//! ```
41//!
42//! # Gradient Tracking
43//!
44//! The squeeze operation supports automatic gradient tracking through
45//! the GradTrack system via the underlying reshape operation. When
46//! `requires_grad` is enabled, the operation preserves gradient
47//! requirements and tracking through the transformation.
48
49use crate::tensor::Tensor;
50
51impl Tensor {
52 /// Remove dimensions of size 1 from the tensor
53 ///
54 /// Removes singleton dimensions (dimensions with size 1) from the tensor,
55 /// reducing its rank while preserving the total number of elements.
56 /// This operation is useful for cleaning up tensor shapes and preparing
57 /// data for operations that expect specific dimensionality.
58 ///
59 /// The squeeze operation can remove either all size-1 dimensions or a
60 /// specific dimension if it has size 1. When all dimensions are size 1,
61 /// the result is a scalar tensor with shape `[1]` rather than an empty
62 /// tensor to maintain mathematical consistency.
63 ///
64 /// # Arguments
65 ///
66 /// * `dim` - Optional specific dimension to squeeze. If `None`, all size-1
67 /// dimensions are removed. If `Some(d)`, only dimension `d` is
68 /// removed if it has size 1.
69 ///
70 /// # Returns
71 ///
72 /// A new tensor with size-1 dimensions removed. The total number of
73 /// elements remains unchanged.
74 ///
75 /// # Panics
76 ///
77 /// * If `dim` is specified but out of bounds for the tensor rank
78 /// * If `dim` is specified but the dimension does not have size 1
79 ///
80 /// # Examples
81 ///
82 /// ```
83 /// use train_station::Tensor;
84 ///
85 /// // Squeeze all size-1 dimensions
86 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
87 /// let squeezed = tensor.squeeze(None);
88 /// assert_eq!(squeezed.shape().dims(), vec![3]);
89 /// assert_eq!(squeezed.get(&[0]), 1.0);
90 /// assert_eq!(squeezed.get(&[1]), 2.0);
91 /// assert_eq!(squeezed.get(&[2]), 3.0);
92 /// ```
93 ///
94 /// ```
95 /// use train_station::Tensor;
96 ///
97 /// // Squeeze specific dimension
98 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
99 /// let squeezed = tensor.squeeze(Some(0));
100 /// assert_eq!(squeezed.shape().dims(), vec![3, 1]);
101 /// assert_eq!(squeezed.get(&[0, 0]), 1.0);
102 /// assert_eq!(squeezed.get(&[1, 0]), 2.0);
103 /// assert_eq!(squeezed.get(&[2, 0]), 3.0);
104 /// ```
105 ///
106 /// ```
107 /// use train_station::Tensor;
108 ///
109 /// // Squeeze preserves data integrity
110 /// let data = vec![1.0, 2.0, 3.0, 4.0];
111 /// let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
112 /// let squeezed = tensor.squeeze(None);
113 /// assert_eq!(squeezed.shape().dims(), vec![2, 2]);
114 /// assert_eq!(squeezed.size(), 4);
115 /// assert_eq!(squeezed.get(&[0, 0]), data[0]);
116 /// assert_eq!(squeezed.get(&[0, 1]), data[1]);
117 /// assert_eq!(squeezed.get(&[1, 0]), data[2]);
118 /// assert_eq!(squeezed.get(&[1, 1]), data[3]);
119 /// ```
120 ///
121 /// ```
122 /// use train_station::Tensor;
123 ///
124 /// // Handle edge case: all dimensions are size 1
125 /// let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
126 /// let squeezed = tensor.squeeze(None);
127 /// assert_eq!(squeezed.shape().dims(), vec![1]); // Not empty!
128 /// assert_eq!(squeezed.get(&[0]), 5.0);
129 /// ```
130 ///
131 /// ```
132 /// use train_station::Tensor;
133 ///
134 /// // Squeeze with gradient tracking
135 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
136 /// tensor.set_requires_grad(true);
137 ///
138 /// let squeezed = tensor.squeeze(None);
139 /// assert!(squeezed.requires_grad());
140 /// assert_eq!(squeezed.shape().dims(), vec![3]);
141 /// ```
142 ///
143 /// ```
144 /// use train_station::Tensor;
145 ///
146 /// // Squeeze and unsqueeze roundtrip
147 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
148 /// let unsqueezed = tensor.unsqueeze(0);
149 /// assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);
150 ///
151 /// let squeezed = unsqueezed.squeeze(Some(0));
152 /// assert_eq!(squeezed.shape().dims(), vec![3]);
153 /// assert_eq!(squeezed.get(&[0]), 1.0);
154 /// assert_eq!(squeezed.get(&[2]), 3.0);
155 /// ```
156 ///
157 /// # Performance
158 ///
159 /// - **Time Complexity**: O(1) - Returns a view through reshape operation
160 /// - **Memory Usage**: No additional memory allocation (view operation)
161 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
162 /// - **Shape Transformation**: Reduces tensor rank by removing singleton dimensions
163 ///
164 /// # Relationship to Other Operations
165 ///
166 /// This operation is related to other tensor transformations:
167 /// - `unsqueeze()` - Inverse operation that adds size-1 dimensions
168 /// - `reshape()` - More general shape transformation operation
169 /// - `flatten()` - Reduces tensor to 1D by combining all dimensions
170 ///
171 /// # Memory Layout
172 ///
173 /// The squeezed tensor maintains the same underlying data as the original
174 /// tensor through the reshape operation. This ensures zero-copy behavior
175 /// when the tensor is contiguous, with only the shape metadata being
176 /// modified to reflect the reduced dimensionality.
177 ///
178 /// # Edge Cases
179 ///
180 /// - **All size-1 dimensions**: Returns a tensor with shape `[1]` rather than
181 /// an empty tensor to maintain mathematical consistency
182 /// - **No size-1 dimensions**: Returns a tensor with the same shape as the input
183 /// - **Mixed dimensions**: Only removes dimensions with size 1, preserving others
184 #[track_caller]
185 pub fn squeeze(&self, dim: Option<usize>) -> Tensor {
186 let mut new_dims = Vec::new();
187
188 if let Some(d) = dim {
189 // Squeeze specific dimension
190 assert!(d < self.shape().rank(), "Dimension {} out of bounds", d);
191 assert_eq!(
192 self.shape().dims()[d],
193 1,
194 "Cannot squeeze dimension {} with size {}",
195 d,
196 self.shape().dims()[d]
197 );
198
199 for (i, &size) in self.shape().dims().iter().enumerate() {
200 if i != d {
201 new_dims.push(size);
202 }
203 }
204 } else {
205 // Squeeze all size-1 dimensions
206 for &size in self.shape().dims() {
207 if size != 1 {
208 new_dims.push(size);
209 }
210 }
211 }
212
213 // Handle edge case where all dimensions were size 1
214 if new_dims.is_empty() {
215 new_dims.push(1);
216 }
217
218 // Convert to i32 for reshape call
219 let new_shape: Vec<i32> = new_dims.iter().map(|&d| d as i32).collect();
220 self.reshape(new_shape)
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_squeeze() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();

        // Squeeze all size-1 dimensions
        let squeezed = tensor.squeeze(None);
        assert_eq!(squeezed.shape().dims(), vec![3]);

        // Squeeze specific dimension
        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims(), vec![3, 1]);
    }

    #[test]
    fn test_squeeze_preserves_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
        let squeezed = tensor.squeeze(None);

        assert_eq!(squeezed.shape().dims(), vec![2, 2]);
        assert_eq!(squeezed.size(), 4);

        // Verify data is preserved
        assert_eq!(squeezed.get(&[0, 0]), data[0]);
        assert_eq!(squeezed.get(&[0, 1]), data[1]);
        assert_eq!(squeezed.get(&[1, 0]), data[2]);
        assert_eq!(squeezed.get(&[1, 1]), data[3]);
    }

    #[test]
    fn test_squeeze_all_ones() {
        let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
        let squeezed = tensor.squeeze(None);

        // When all dimensions are 1, result should be [1] not []
        assert_eq!(squeezed.shape().dims(), vec![1]);
        assert_eq!(squeezed.get(&[0]), 5.0);
    }

    #[test]
    fn test_squeeze_specific_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]).unwrap();

        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims(), vec![2, 2]);

        // Dropping the leading unit axis must not reorder elements.
        assert_eq!(squeezed.get(&[0, 0]), 1.0);
        assert_eq!(squeezed.get(&[0, 1]), 2.0);
        assert_eq!(squeezed.get(&[1, 0]), 3.0);
        assert_eq!(squeezed.get(&[1, 1]), 4.0);
    }

    #[test]
    fn test_squeeze_trailing_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();

        let squeezed = tensor.squeeze(Some(2));
        assert_eq!(squeezed.shape().dims(), vec![1, 3]);
        assert_eq!(squeezed.get(&[0, 0]), 1.0);
        assert_eq!(squeezed.get(&[0, 2]), 3.0);
    }

    #[test]
    fn test_squeeze_no_unit_dimensions() {
        // Documented edge case: nothing to remove, so the shape is unchanged.
        let tensor =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        let squeezed = tensor.squeeze(None);
        assert_eq!(squeezed.shape().dims(), vec![2, 3]);
        assert_eq!(squeezed.get(&[0, 0]), 1.0);
        assert_eq!(squeezed.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_squeeze_specific_dimension_of_rank_one_unit_tensor() {
        // Removing the only dimension would leave an empty shape, so the
        // result keeps shape [1].
        let tensor = Tensor::from_slice(&[7.0], vec![1]).unwrap();

        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims(), vec![1]);
        assert_eq!(squeezed.get(&[0]), 7.0);
    }

    #[test]
    fn test_squeeze_preserves_requires_grad() {
        let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
        tensor.set_requires_grad(true);

        let squeezed = tensor.squeeze(None);
        assert!(squeezed.requires_grad());
        assert_eq!(squeezed.shape().dims(), vec![3]);
    }

    #[test]
    #[should_panic(expected = "Dimension 3 out of bounds")]
    fn test_squeeze_out_of_bounds() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        tensor.squeeze(Some(3)); // Should panic
    }

    #[test]
    #[should_panic(expected = "Cannot squeeze dimension 0 with size 3")]
    fn test_squeeze_non_unit_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        tensor.squeeze(Some(0)); // Should panic - dimension 0 has size 3, not 1
    }

    #[test]
    fn test_squeeze_unsqueeze_roundtrip() {
        // Test that squeeze and unsqueeze are inverses
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        let unsqueezed = tensor.unsqueeze(0);
        assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);

        let squeezed = unsqueezed.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims(), vec![3]);

        // Verify data integrity
        assert_eq!(squeezed.get(&[0]), 1.0);
        assert_eq!(squeezed.get(&[2]), 3.0);
    }
}