train_station/tensor/transform/
squeeze.rs

1//! Tensor squeeze operations
2//!
3//! This module provides tensor squeeze functionality that removes dimensions
4//! of size 1 from tensors, effectively reducing the dimensionality while
5//! preserving the total number of elements. Squeezing is a fundamental
6//! tensor transformation operation used in machine learning for cleaning
7//! up tensor shapes, preparing data for specific operations, and
8//! implementing shape normalization.
9//!
10//! # Operations
11//!
12//! * `squeeze()` - Remove dimensions of size 1 from tensor
13//!
14//! # Performance Characteristics
15//!
16//! * **Zero-Copy Operation**: Returns a view when possible, avoiding data copying
17//! * **Memory Efficient**: Reuses existing tensor data through reshape operations
18//! * **Shape Reduction**: Reduces tensor rank by removing singleton dimensions
19//! * **Gradient Tracking**: Full GradTrack support through reshape operations
20//! * **Edge Case Handling**: Properly handles tensors with all size-1 dimensions
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Squeeze all size-1 dimensions
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
29//! let squeezed = tensor.squeeze(None);
30//! assert_eq!(squeezed.shape().dims, vec![3]);
31//! ```
32//!
33//! ```
34//! use train_station::Tensor;
35//!
36//! // Squeeze specific dimension
37//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
38//! let squeezed = tensor.squeeze(Some(0));
39//! assert_eq!(squeezed.shape().dims, vec![3, 1]);
40//! ```
41//!
42//! # Gradient Tracking
43//!
44//! The squeeze operation supports automatic gradient tracking through
45//! the GradTrack system via the underlying reshape operation. When
46//! `requires_grad` is enabled, the operation preserves gradient
47//! requirements and tracking through the transformation.
48
49use crate::tensor::Tensor;
50
51impl Tensor {
52    /// Remove dimensions of size 1 from the tensor
53    ///
54    /// Removes singleton dimensions (dimensions with size 1) from the tensor,
55    /// reducing its rank while preserving the total number of elements.
56    /// This operation is useful for cleaning up tensor shapes and preparing
57    /// data for operations that expect specific dimensionality.
58    ///
59    /// The squeeze operation can remove either all size-1 dimensions or a
60    /// specific dimension if it has size 1. When all dimensions are size 1,
61    /// the result is a scalar tensor with shape `[1]` rather than an empty
62    /// tensor to maintain mathematical consistency.
63    ///
64    /// # Arguments
65    ///
66    /// * `dim` - Optional specific dimension to squeeze. If `None`, all size-1
67    ///   dimensions are removed. If `Some(d)`, only dimension `d` is
68    ///   removed if it has size 1.
69    ///
70    /// # Returns
71    ///
72    /// A new tensor with size-1 dimensions removed. The total number of
73    /// elements remains unchanged.
74    ///
75    /// # Panics
76    ///
77    /// * If `dim` is specified but out of bounds for the tensor rank
78    /// * If `dim` is specified but the dimension does not have size 1
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use train_station::Tensor;
84    ///
85    /// // Squeeze all size-1 dimensions
86    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
87    /// let squeezed = tensor.squeeze(None);
88    /// assert_eq!(squeezed.shape().dims, vec![3]);
89    /// assert_eq!(squeezed.get(&[0]), 1.0);
90    /// assert_eq!(squeezed.get(&[1]), 2.0);
91    /// assert_eq!(squeezed.get(&[2]), 3.0);
92    /// ```
93    ///
94    /// ```
95    /// use train_station::Tensor;
96    ///
97    /// // Squeeze specific dimension
98    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
99    /// let squeezed = tensor.squeeze(Some(0));
100    /// assert_eq!(squeezed.shape().dims, vec![3, 1]);
101    /// assert_eq!(squeezed.get(&[0, 0]), 1.0);
102    /// assert_eq!(squeezed.get(&[1, 0]), 2.0);
103    /// assert_eq!(squeezed.get(&[2, 0]), 3.0);
104    /// ```
105    ///
106    /// ```
107    /// use train_station::Tensor;
108    ///
109    /// // Squeeze preserves data integrity
110    /// let data = vec![1.0, 2.0, 3.0, 4.0];
111    /// let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
112    /// let squeezed = tensor.squeeze(None);
113    /// assert_eq!(squeezed.shape().dims, vec![2, 2]);
114    /// assert_eq!(squeezed.size(), 4);
115    /// assert_eq!(squeezed.get(&[0, 0]), data[0]);
116    /// assert_eq!(squeezed.get(&[0, 1]), data[1]);
117    /// assert_eq!(squeezed.get(&[1, 0]), data[2]);
118    /// assert_eq!(squeezed.get(&[1, 1]), data[3]);
119    /// ```
120    ///
121    /// ```
122    /// use train_station::Tensor;
123    ///
124    /// // Handle edge case: all dimensions are size 1
125    /// let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
126    /// let squeezed = tensor.squeeze(None);
127    /// assert_eq!(squeezed.shape().dims, vec![1]); // Not empty!
128    /// assert_eq!(squeezed.get(&[0]), 5.0);
129    /// ```
130    ///
131    /// ```
132    /// use train_station::Tensor;
133    ///
134    /// // Squeeze with gradient tracking
135    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
136    /// tensor.set_requires_grad(true);
137    ///
138    /// let squeezed = tensor.squeeze(None);
139    /// assert!(squeezed.requires_grad());
140    /// assert_eq!(squeezed.shape().dims, vec![3]);
141    /// ```
142    ///
143    /// ```
144    /// use train_station::Tensor;
145    ///
146    /// // Squeeze and unsqueeze roundtrip
147    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
148    /// let unsqueezed = tensor.unsqueeze(0);
149    /// assert_eq!(unsqueezed.shape().dims, vec![1, 3]);
150    ///
151    /// let squeezed = unsqueezed.squeeze(Some(0));
152    /// assert_eq!(squeezed.shape().dims, vec![3]);
153    /// assert_eq!(squeezed.get(&[0]), 1.0);
154    /// assert_eq!(squeezed.get(&[2]), 3.0);
155    /// ```
156    ///
157    /// # Performance
158    ///
159    /// - **Time Complexity**: O(1) - Returns a view through reshape operation
160    /// - **Memory Usage**: No additional memory allocation (view operation)
161    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
162    /// - **Shape Transformation**: Reduces tensor rank by removing singleton dimensions
163    ///
164    /// # Relationship to Other Operations
165    ///
166    /// This operation is related to other tensor transformations:
167    /// - `unsqueeze()` - Inverse operation that adds size-1 dimensions
168    /// - `reshape()` - More general shape transformation operation
169    /// - `flatten()` - Reduces tensor to 1D by combining all dimensions
170    ///
171    /// # Memory Layout
172    ///
173    /// The squeezed tensor maintains the same underlying data as the original
174    /// tensor through the reshape operation. This ensures zero-copy behavior
175    /// when the tensor is contiguous, with only the shape metadata being
176    /// modified to reflect the reduced dimensionality.
177    ///
178    /// # Edge Cases
179    ///
180    /// - **All size-1 dimensions**: Returns a tensor with shape `[1]` rather than
181    ///   an empty tensor to maintain mathematical consistency
182    /// - **No size-1 dimensions**: Returns a tensor with the same shape as the input
183    /// - **Mixed dimensions**: Only removes dimensions with size 1, preserving others
184    pub fn squeeze(&self, dim: Option<usize>) -> Tensor {
185        let mut new_dims = Vec::new();
186
187        if let Some(d) = dim {
188            // Squeeze specific dimension
189            assert!(d < self.shape().rank(), "Dimension {} out of bounds", d);
190            assert_eq!(
191                self.shape().dims[d],
192                1,
193                "Cannot squeeze dimension {} with size {}",
194                d,
195                self.shape().dims[d]
196            );
197
198            for (i, &size) in self.shape().dims.iter().enumerate() {
199                if i != d {
200                    new_dims.push(size);
201                }
202            }
203        } else {
204            // Squeeze all size-1 dimensions
205            for &size in &self.shape().dims {
206                if size != 1 {
207                    new_dims.push(size);
208                }
209            }
210        }
211
212        // Handle edge case where all dimensions were size 1
213        if new_dims.is_empty() {
214            new_dims.push(1);
215        }
216
217        // Convert to i32 for reshape call
218        let new_shape: Vec<i32> = new_dims.iter().map(|&d| d as i32).collect();
219        self.reshape(new_shape)
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_squeeze() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();

        // Squeeze all size-1 dimensions
        let squeezed = tensor.squeeze(None);
        assert_eq!(squeezed.shape().dims, vec![3]);

        // Squeeze specific dimension
        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims, vec![3, 1]);
    }

    #[test]
    fn test_squeeze_preserves_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
        let squeezed = tensor.squeeze(None);

        assert_eq!(squeezed.shape().dims, vec![2, 2]);
        assert_eq!(squeezed.size(), 4);

        // Verify data is preserved
        assert_eq!(squeezed.get(&[0, 0]), data[0]);
        assert_eq!(squeezed.get(&[0, 1]), data[1]);
        assert_eq!(squeezed.get(&[1, 0]), data[2]);
        assert_eq!(squeezed.get(&[1, 1]), data[3]);
    }

    #[test]
    fn test_squeeze_all_ones() {
        let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
        let squeezed = tensor.squeeze(None);

        // When all dimensions are 1, result should be [1] not []
        assert_eq!(squeezed.shape().dims, vec![1]);
        assert_eq!(squeezed.get(&[0]), 5.0);
    }

    #[test]
    fn test_squeeze_specific_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 2, 2]).unwrap();

        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims, vec![2, 2]);
    }

    #[test]
    fn test_squeeze_last_dimension() {
        // Squeezing a trailing singleton must keep the leading data order intact.
        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2, 1]).unwrap();

        let squeezed = tensor.squeeze(Some(1));
        assert_eq!(squeezed.shape().dims, vec![2]);
        assert_eq!(squeezed.get(&[0]), 1.0);
        assert_eq!(squeezed.get(&[1]), 2.0);
    }

    #[test]
    fn test_squeeze_no_unit_dimensions_is_noop() {
        // Documented edge case: with no size-1 dimensions the shape is unchanged.
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        let squeezed = tensor.squeeze(None);
        assert_eq!(squeezed.shape().dims, vec![2, 3]);
        assert_eq!(squeezed.get(&[0, 0]), 1.0);
        assert_eq!(squeezed.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_squeeze_specific_dimension_of_rank_one_unit_tensor() {
        // Removing the only dimension falls back to shape [1], not [].
        let tensor = Tensor::from_slice(&[7.0], vec![1]).unwrap();

        let squeezed = tensor.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims, vec![1]);
        assert_eq!(squeezed.get(&[0]), 7.0);
    }

    #[test]
    fn test_squeeze_preserves_requires_grad() {
        // Documented behavior: gradient requirements survive the squeeze.
        let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
        tensor.set_requires_grad(true);

        let squeezed = tensor.squeeze(None);
        assert!(squeezed.requires_grad());
        assert_eq!(squeezed.shape().dims, vec![3]);
    }

    #[test]
    #[should_panic(expected = "Dimension 3 out of bounds")]
    fn test_squeeze_out_of_bounds() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        tensor.squeeze(Some(3)); // Should panic
    }

    #[test]
    #[should_panic(expected = "Cannot squeeze dimension 0 with size 3")]
    fn test_squeeze_non_unit_dimension() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        tensor.squeeze(Some(0)); // Should panic - dimension 0 has size 3, not 1
    }

    #[test]
    fn test_squeeze_unsqueeze_roundtrip() {
        // Test that squeeze and unsqueeze are inverses
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

        let unsqueezed = tensor.unsqueeze(0);
        assert_eq!(unsqueezed.shape().dims, vec![1, 3]);

        let squeezed = unsqueezed.squeeze(Some(0));
        assert_eq!(squeezed.shape().dims, vec![3]);

        // Verify data integrity
        assert_eq!(squeezed.get(&[0]), 1.0);
        assert_eq!(squeezed.get(&[2]), 3.0);
    }
}