train_station/tensor/transform/
squeeze.rs

1//! Tensor squeeze operations
2//!
3//! This module provides tensor squeeze functionality that removes dimensions
4//! of size 1 from tensors, effectively reducing the dimensionality while
5//! preserving the total number of elements. Squeezing is a fundamental
6//! tensor transformation operation used in machine learning for cleaning
7//! up tensor shapes, preparing data for specific operations, and
8//! implementing shape normalization.
9//!
10//! # Operations
11//!
12//! * `squeeze()` - Remove dimensions of size 1 from tensor
13//!
14//! # Performance Characteristics
15//!
16//! * **Zero-Copy Operation**: Returns a view when possible, avoiding data copying
17//! * **Memory Efficient**: Reuses existing tensor data through reshape operations
18//! * **Shape Reduction**: Reduces tensor rank by removing singleton dimensions
19//! * **Gradient Tracking**: Full GradTrack support through reshape operations
20//! * **Edge Case Handling**: Properly handles tensors with all size-1 dimensions
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Squeeze all size-1 dimensions
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
29//! let squeezed = tensor.squeeze(None);
30//! assert_eq!(squeezed.shape().dims(), vec![3]);
31//! ```
32//!
33//! ```
34//! use train_station::Tensor;
35//!
36//! // Squeeze specific dimension
37//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
38//! let squeezed = tensor.squeeze(Some(0));
39//! assert_eq!(squeezed.shape().dims(), vec![3, 1]);
40//! ```
41//!
42//! # Gradient Tracking
43//!
44//! The squeeze operation supports automatic gradient tracking through
45//! the GradTrack system via the underlying reshape operation. When
46//! `requires_grad` is enabled, the operation preserves gradient
47//! requirements and tracking through the transformation.
48
49use crate::tensor::Tensor;
50
51impl Tensor {
52    /// Remove dimensions of size 1 from the tensor
53    ///
54    /// Removes singleton dimensions (dimensions with size 1) from the tensor,
55    /// reducing its rank while preserving the total number of elements.
56    /// This operation is useful for cleaning up tensor shapes and preparing
57    /// data for operations that expect specific dimensionality.
58    ///
59    /// The squeeze operation can remove either all size-1 dimensions or a
60    /// specific dimension if it has size 1. When all dimensions are size 1,
61    /// the result is a scalar tensor with shape `[1]` rather than an empty
62    /// tensor to maintain mathematical consistency.
63    ///
64    /// # Arguments
65    ///
66    /// * `dim` - Optional specific dimension to squeeze. If `None`, all size-1
67    ///   dimensions are removed. If `Some(d)`, only dimension `d` is
68    ///   removed if it has size 1.
69    ///
70    /// # Returns
71    ///
72    /// A new tensor with size-1 dimensions removed. The total number of
73    /// elements remains unchanged.
74    ///
75    /// # Panics
76    ///
77    /// * If `dim` is specified but out of bounds for the tensor rank
78    /// * If `dim` is specified but the dimension does not have size 1
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use train_station::Tensor;
84    ///
85    /// // Squeeze all size-1 dimensions
86    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
87    /// let squeezed = tensor.squeeze(None);
88    /// assert_eq!(squeezed.shape().dims(), vec![3]);
89    /// assert_eq!(squeezed.get(&[0]), 1.0);
90    /// assert_eq!(squeezed.get(&[1]), 2.0);
91    /// assert_eq!(squeezed.get(&[2]), 3.0);
92    /// ```
93    ///
94    /// ```
95    /// use train_station::Tensor;
96    ///
97    /// // Squeeze specific dimension
98    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
99    /// let squeezed = tensor.squeeze(Some(0));
100    /// assert_eq!(squeezed.shape().dims(), vec![3, 1]);
101    /// assert_eq!(squeezed.get(&[0, 0]), 1.0);
102    /// assert_eq!(squeezed.get(&[1, 0]), 2.0);
103    /// assert_eq!(squeezed.get(&[2, 0]), 3.0);
104    /// ```
105    ///
106    /// ```
107    /// use train_station::Tensor;
108    ///
109    /// // Squeeze preserves data integrity
110    /// let data = vec![1.0, 2.0, 3.0, 4.0];
111    /// let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
112    /// let squeezed = tensor.squeeze(None);
113    /// assert_eq!(squeezed.shape().dims(), vec![2, 2]);
114    /// assert_eq!(squeezed.size(), 4);
115    /// assert_eq!(squeezed.get(&[0, 0]), data[0]);
116    /// assert_eq!(squeezed.get(&[0, 1]), data[1]);
117    /// assert_eq!(squeezed.get(&[1, 0]), data[2]);
118    /// assert_eq!(squeezed.get(&[1, 1]), data[3]);
119    /// ```
120    ///
121    /// ```
122    /// use train_station::Tensor;
123    ///
124    /// // Handle edge case: all dimensions are size 1
125    /// let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
126    /// let squeezed = tensor.squeeze(None);
127    /// assert_eq!(squeezed.shape().dims(), vec![1]); // Not empty!
128    /// assert_eq!(squeezed.get(&[0]), 5.0);
129    /// ```
130    ///
131    /// ```
132    /// use train_station::Tensor;
133    ///
134    /// // Squeeze with gradient tracking
135    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
136    /// tensor.set_requires_grad(true);
137    ///
138    /// let squeezed = tensor.squeeze(None);
139    /// assert!(squeezed.requires_grad());
140    /// assert_eq!(squeezed.shape().dims(), vec![3]);
141    /// ```
142    ///
143    /// ```
144    /// use train_station::Tensor;
145    ///
146    /// // Squeeze and unsqueeze roundtrip
147    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
148    /// let unsqueezed = tensor.unsqueeze(0);
149    /// assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);
150    ///
151    /// let squeezed = unsqueezed.squeeze(Some(0));
152    /// assert_eq!(squeezed.shape().dims(), vec![3]);
153    /// assert_eq!(squeezed.get(&[0]), 1.0);
154    /// assert_eq!(squeezed.get(&[2]), 3.0);
155    /// ```
156    ///
157    /// # Performance
158    ///
159    /// - **Time Complexity**: O(1) - Returns a view through reshape operation
160    /// - **Memory Usage**: No additional memory allocation (view operation)
161    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
162    /// - **Shape Transformation**: Reduces tensor rank by removing singleton dimensions
163    ///
164    /// # Relationship to Other Operations
165    ///
166    /// This operation is related to other tensor transformations:
167    /// - `unsqueeze()` - Inverse operation that adds size-1 dimensions
168    /// - `reshape()` - More general shape transformation operation
169    /// - `flatten()` - Reduces tensor to 1D by combining all dimensions
170    ///
171    /// # Memory Layout
172    ///
173    /// The squeezed tensor maintains the same underlying data as the original
174    /// tensor through the reshape operation. This ensures zero-copy behavior
175    /// when the tensor is contiguous, with only the shape metadata being
176    /// modified to reflect the reduced dimensionality.
177    ///
178    /// # Edge Cases
179    ///
180    /// - **All size-1 dimensions**: Returns a tensor with shape `[1]` rather than
181    ///   an empty tensor to maintain mathematical consistency
182    /// - **No size-1 dimensions**: Returns a tensor with the same shape as the input
183    /// - **Mixed dimensions**: Only removes dimensions with size 1, preserving others
184    #[track_caller]
185    pub fn squeeze(&self, dim: Option<usize>) -> Tensor {
186        let mut new_dims = Vec::new();
187
188        if let Some(d) = dim {
189            // Squeeze specific dimension
190            assert!(d < self.shape().rank(), "Dimension {} out of bounds", d);
191            assert_eq!(
192                self.shape().dims()[d],
193                1,
194                "Cannot squeeze dimension {} with size {}",
195                d,
196                self.shape().dims()[d]
197            );
198
199            for (i, &size) in self.shape().dims().iter().enumerate() {
200                if i != d {
201                    new_dims.push(size);
202                }
203            }
204        } else {
205            // Squeeze all size-1 dimensions
206            for &size in self.shape().dims() {
207                if size != 1 {
208                    new_dims.push(size);
209                }
210            }
211        }
212
213        // Handle edge case where all dimensions were size 1
214        if new_dims.is_empty() {
215            new_dims.push(1);
216        }
217
218        // Convert to i32 for reshape call
219        let new_shape: Vec<i32> = new_dims.iter().map(|&d| d as i32).collect();
220        self.reshape(new_shape)
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_squeeze() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();

        // Squeeze all size-1 dimensions
        let squeezed = tensor.squeeze(None);
        assert_eq!(squeezed.shape().dims(), vec![3]);

        // Squeeze specific dimension
        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims(), vec![3, 1]);
    }

    #[test]
    fn test_squeeze_preserves_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
        let squeezed = tensor.squeeze(None);

        assert_eq!(squeezed.shape().dims(), vec![2, 2]);
        assert_eq!(squeezed.size(), 4);

        // Verify data is preserved
        assert_eq!(squeezed.get(&[0, 0]), data[0]);
        assert_eq!(squeezed.get(&[0, 1]), data[1]);
        assert_eq!(squeezed.get(&[1, 0]), data[2]);
        assert_eq!(squeezed.get(&[1, 1]), data[3]);
    }

    #[test]
    fn test_squeeze_all_ones() {
        let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
        let squeezed = tensor.squeeze(None);

        // When all dimensions are 1, result should be [1] not []
        assert_eq!(squeezed.shape().dims(), vec![1]);
        assert_eq!(squeezed.get(&[0]), 5.0);
    }

    #[test]
    fn test_squeeze_specific_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]).unwrap();

        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims(), vec![2, 2]);
    }

    #[test]
    fn test_squeeze_last_dimension() {
        // Removing a trailing size-1 dim must keep the leading size-1 dim.
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();

        let squeezed = tensor.squeeze(Some(2));
        assert_eq!(squeezed.shape().dims(), vec![1, 3]);
        assert_eq!(squeezed.get(&[0, 1]), 2.0);
    }

    #[test]
    fn test_squeeze_no_unit_dimensions_keeps_shape() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        let squeezed = tensor.squeeze(None);
        assert_eq!(squeezed.shape().dims(), vec![2, 2]);
        assert_eq!(squeezed.get(&[1, 1]), 4.0);
    }

    #[test]
    fn test_squeeze_single_element_specific_dimension() {
        // Removing the only dimension falls back to shape [1].
        let tensor = Tensor::from_slice(&[7.0], vec![1]).unwrap();

        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims(), vec![1]);
        assert_eq!(squeezed.get(&[0]), 7.0);
    }

    #[test]
    fn test_squeeze_preserves_requires_grad() {
        let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
        tensor.set_requires_grad(true);

        let squeezed = tensor.squeeze(None);
        assert!(squeezed.requires_grad());
        assert_eq!(squeezed.shape().dims(), vec![3]);
    }

    #[test]
    #[should_panic(expected = "Dimension 3 out of bounds")]
    fn test_squeeze_out_of_bounds() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let _ = tensor.squeeze(Some(3)); // Should panic
    }

    #[test]
    #[should_panic(expected = "Cannot squeeze dimension 0 with size 3")]
    fn test_squeeze_non_unit_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let _ = tensor.squeeze(Some(0)); // Should panic - dimension 0 has size 3, not 1
    }

    #[test]
    fn test_squeeze_unsqueeze_roundtrip() {
        // Test that squeeze and unsqueeze are inverses
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        let unsqueezed = tensor.unsqueeze(0);
        assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);

        let squeezed = unsqueezed.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims(), vec![3]);

        // Verify data integrity
        assert_eq!(squeezed.get(&[0]), 1.0);
        assert_eq!(squeezed.get(&[2]), 3.0);
    }
}