train_station/tensor/transform/squeeze.rs
1//! Tensor squeeze operations
2//!
3//! This module provides tensor squeeze functionality that removes dimensions
4//! of size 1 from tensors, effectively reducing the dimensionality while
5//! preserving the total number of elements. Squeezing is a fundamental
6//! tensor transformation operation used in machine learning for cleaning
7//! up tensor shapes, preparing data for specific operations, and
8//! implementing shape normalization.
9//!
10//! # Operations
11//!
12//! * `squeeze()` - Remove dimensions of size 1 from tensor
13//!
14//! # Performance Characteristics
15//!
16//! * **Zero-Copy Operation**: Returns a view when possible, avoiding data copying
17//! * **Memory Efficient**: Reuses existing tensor data through reshape operations
18//! * **Shape Reduction**: Reduces tensor rank by removing singleton dimensions
19//! * **Gradient Tracking**: Full GradTrack support through reshape operations
20//! * **Edge Case Handling**: Properly handles tensors with all size-1 dimensions
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Squeeze all size-1 dimensions
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
29//! let squeezed = tensor.squeeze(None);
30//! assert_eq!(squeezed.shape().dims, vec![3]);
31//! ```
32//!
33//! ```
34//! use train_station::Tensor;
35//!
36//! // Squeeze specific dimension
37//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
38//! let squeezed = tensor.squeeze(Some(0));
39//! assert_eq!(squeezed.shape().dims, vec![3, 1]);
40//! ```
41//!
42//! # Gradient Tracking
43//!
44//! The squeeze operation supports automatic gradient tracking through
45//! the GradTrack system via the underlying reshape operation. When
46//! `requires_grad` is enabled, the operation preserves gradient
47//! requirements and tracking through the transformation.
48
49use crate::tensor::Tensor;
50
51impl Tensor {
52 /// Remove dimensions of size 1 from the tensor
53 ///
54 /// Removes singleton dimensions (dimensions with size 1) from the tensor,
55 /// reducing its rank while preserving the total number of elements.
56 /// This operation is useful for cleaning up tensor shapes and preparing
57 /// data for operations that expect specific dimensionality.
58 ///
59 /// The squeeze operation can remove either all size-1 dimensions or a
60 /// specific dimension if it has size 1. When all dimensions are size 1,
61 /// the result is a scalar tensor with shape `[1]` rather than an empty
62 /// tensor to maintain mathematical consistency.
63 ///
64 /// # Arguments
65 ///
66 /// * `dim` - Optional specific dimension to squeeze. If `None`, all size-1
67 /// dimensions are removed. If `Some(d)`, only dimension `d` is
68 /// removed if it has size 1.
69 ///
70 /// # Returns
71 ///
72 /// A new tensor with size-1 dimensions removed. The total number of
73 /// elements remains unchanged.
74 ///
75 /// # Panics
76 ///
77 /// * If `dim` is specified but out of bounds for the tensor rank
78 /// * If `dim` is specified but the dimension does not have size 1
79 ///
80 /// # Examples
81 ///
82 /// ```
83 /// use train_station::Tensor;
84 ///
85 /// // Squeeze all size-1 dimensions
86 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
87 /// let squeezed = tensor.squeeze(None);
88 /// assert_eq!(squeezed.shape().dims, vec![3]);
89 /// assert_eq!(squeezed.get(&[0]), 1.0);
90 /// assert_eq!(squeezed.get(&[1]), 2.0);
91 /// assert_eq!(squeezed.get(&[2]), 3.0);
92 /// ```
93 ///
94 /// ```
95 /// use train_station::Tensor;
96 ///
97 /// // Squeeze specific dimension
98 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
99 /// let squeezed = tensor.squeeze(Some(0));
100 /// assert_eq!(squeezed.shape().dims, vec![3, 1]);
101 /// assert_eq!(squeezed.get(&[0, 0]), 1.0);
102 /// assert_eq!(squeezed.get(&[1, 0]), 2.0);
103 /// assert_eq!(squeezed.get(&[2, 0]), 3.0);
104 /// ```
105 ///
106 /// ```
107 /// use train_station::Tensor;
108 ///
109 /// // Squeeze preserves data integrity
110 /// let data = vec![1.0, 2.0, 3.0, 4.0];
111 /// let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
112 /// let squeezed = tensor.squeeze(None);
113 /// assert_eq!(squeezed.shape().dims, vec![2, 2]);
114 /// assert_eq!(squeezed.size(), 4);
115 /// assert_eq!(squeezed.get(&[0, 0]), data[0]);
116 /// assert_eq!(squeezed.get(&[0, 1]), data[1]);
117 /// assert_eq!(squeezed.get(&[1, 0]), data[2]);
118 /// assert_eq!(squeezed.get(&[1, 1]), data[3]);
119 /// ```
120 ///
121 /// ```
122 /// use train_station::Tensor;
123 ///
124 /// // Handle edge case: all dimensions are size 1
125 /// let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
126 /// let squeezed = tensor.squeeze(None);
127 /// assert_eq!(squeezed.shape().dims, vec![1]); // Not empty!
128 /// assert_eq!(squeezed.get(&[0]), 5.0);
129 /// ```
130 ///
131 /// ```
132 /// use train_station::Tensor;
133 ///
134 /// // Squeeze with gradient tracking
135 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
136 /// tensor.set_requires_grad(true);
137 ///
138 /// let squeezed = tensor.squeeze(None);
139 /// assert!(squeezed.requires_grad());
140 /// assert_eq!(squeezed.shape().dims, vec![3]);
141 /// ```
142 ///
143 /// ```
144 /// use train_station::Tensor;
145 ///
146 /// // Squeeze and unsqueeze roundtrip
147 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
148 /// let unsqueezed = tensor.unsqueeze(0);
149 /// assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
150 ///
151 /// let squeezed = unsqueezed.squeeze(Some(0));
152 /// assert_eq!(squeezed.shape().dims, vec![3]);
153 /// assert_eq!(squeezed.get(&[0]), 1.0);
154 /// assert_eq!(squeezed.get(&[2]), 3.0);
155 /// ```
156 ///
157 /// # Performance
158 ///
159 /// - **Time Complexity**: O(1) - Returns a view through reshape operation
160 /// - **Memory Usage**: No additional memory allocation (view operation)
161 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
162 /// - **Shape Transformation**: Reduces tensor rank by removing singleton dimensions
163 ///
164 /// # Relationship to Other Operations
165 ///
166 /// This operation is related to other tensor transformations:
167 /// - `unsqueeze()` - Inverse operation that adds size-1 dimensions
168 /// - `reshape()` - More general shape transformation operation
169 /// - `flatten()` - Reduces tensor to 1D by combining all dimensions
170 ///
171 /// # Memory Layout
172 ///
173 /// The squeezed tensor maintains the same underlying data as the original
174 /// tensor through the reshape operation. This ensures zero-copy behavior
175 /// when the tensor is contiguous, with only the shape metadata being
176 /// modified to reflect the reduced dimensionality.
177 ///
178 /// # Edge Cases
179 ///
180 /// - **All size-1 dimensions**: Returns a tensor with shape `[1]` rather than
181 /// an empty tensor to maintain mathematical consistency
182 /// - **No size-1 dimensions**: Returns a tensor with the same shape as the input
183 /// - **Mixed dimensions**: Only removes dimensions with size 1, preserving others
184 pub fn squeeze(&self, dim: Option<usize>) -> Tensor {
185 let mut new_dims = Vec::new();
186
187 if let Some(d) = dim {
188 // Squeeze specific dimension
189 assert!(d < self.shape().rank(), "Dimension {} out of bounds", d);
190 assert_eq!(
191 self.shape().dims[d],
192 1,
193 "Cannot squeeze dimension {} with size {}",
194 d,
195 self.shape().dims[d]
196 );
197
198 for (i, &size) in self.shape().dims.iter().enumerate() {
199 if i != d {
200 new_dims.push(size);
201 }
202 }
203 } else {
204 // Squeeze all size-1 dimensions
205 for &size in &self.shape().dims {
206 if size != 1 {
207 new_dims.push(size);
208 }
209 }
210 }
211
212 // Handle edge case where all dimensions were size 1
213 if new_dims.is_empty() {
214 new_dims.push(1);
215 }
216
217 // Convert to i32 for reshape call
218 let new_shape: Vec<i32> = new_dims.iter().map(|&d| d as i32).collect();
219 self.reshape(new_shape)
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_squeeze() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();

        // Squeeze all size-1 dimensions
        let squeezed = tensor.squeeze(None);
        assert_eq!(squeezed.shape().dims, vec![3]);

        // Squeeze specific dimension
        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims, vec![3, 1]);
    }

    #[test]
    fn test_squeeze_preserves_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
        let squeezed = tensor.squeeze(None);

        assert_eq!(squeezed.shape().dims, vec![2, 2]);
        assert_eq!(squeezed.size(), 4);

        // Verify data is preserved
        assert_eq!(squeezed.get(&[0, 0]), data[0]);
        assert_eq!(squeezed.get(&[0, 1]), data[1]);
        assert_eq!(squeezed.get(&[1, 0]), data[2]);
        assert_eq!(squeezed.get(&[1, 1]), data[3]);
    }

    #[test]
    fn test_squeeze_all_ones() {
        let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
        let squeezed = tensor.squeeze(None);

        // When all dimensions are 1, result should be [1] not []
        assert_eq!(squeezed.shape().dims, vec![1]);
        assert_eq!(squeezed.get(&[0]), 5.0);
    }

    #[test]
    fn test_squeeze_specific_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]).unwrap();

        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims, vec![2, 2]);
    }

    #[test]
    fn test_squeeze_last_dimension() {
        // Removing the trailing unit dimension keeps the leading one.
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();

        let squeezed = tensor.squeeze(Some(2));
        assert_eq!(squeezed.shape().dims, vec![1, 3]);
        assert_eq!(squeezed.get(&[0, 0]), 1.0);
        assert_eq!(squeezed.get(&[0, 2]), 3.0);
    }

    #[test]
    fn test_squeeze_without_unit_dimensions_keeps_shape() {
        // With no size-1 dims, squeeze(None) must be a shape no-op.
        let tensor =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        let squeezed = tensor.squeeze(None);
        assert_eq!(squeezed.shape().dims, vec![2, 3]);
        assert_eq!(squeezed.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_squeeze_specific_dim_of_single_element() {
        // Explicitly squeezing the only dimension falls back to shape [1].
        let tensor = Tensor::from_slice(&[7.0], vec![1]).unwrap();

        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims, vec![1]);
        assert_eq!(squeezed.get(&[0]), 7.0);
    }

    #[test]
    fn test_squeeze_preserves_requires_grad() {
        let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
        tensor.set_requires_grad(true);

        let squeezed = tensor.squeeze(None);
        assert!(squeezed.requires_grad());
        assert_eq!(squeezed.shape().dims, vec![3]);
    }

    #[test]
    #[should_panic(expected = "Dimension 3 out of bounds")]
    fn test_squeeze_out_of_bounds() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        tensor.squeeze(Some(3)); // Should panic
    }

    #[test]
    #[should_panic(expected = "Cannot squeeze dimension 0 with size 3")]
    fn test_squeeze_non_unit_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        tensor.squeeze(Some(0)); // Should panic - dimension 0 has size 3, not 1
    }

    #[test]
    fn test_squeeze_unsqueeze_roundtrip() {
        // Test that squeeze and unsqueeze are inverses
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        let unsqueezed = tensor.unsqueeze(0);
        assert_eq!(unsqueezed.shape().dims, vec![1, 3]);

        let squeezed = unsqueezed.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims, vec![3]);

        // Verify data integrity
        assert_eq!(squeezed.get(&[0]), 1.0);
        assert_eq!(squeezed.get(&[2]), 3.0);
    }
}