// slsl/ops/squeeze.rs

1use anyhow::Result;
2
3use crate::{Dim, Dims, Shape, StorageTrait, TensorBase, TensorView};
4
5impl<S: StorageTrait> TensorBase<S> {
6    /// Returns a new tensor with dimensions of size one removed from the specified positions.
7    ///
8    /// This function creates a new view of the tensor with dimensions of size 1 removed
9    /// from the specified positions. The returned tensor shares the same underlying data
10    /// as the original tensor, making this a zero-copy operation.
11    ///
12    /// # Arguments
13    ///
14    /// * `dims` - The dimensions to squeeze. Must be valid dimension indices for this tensor.
15    ///   - Can be a single dimension index, a slice of indices, or a range
16    ///   - Only dimensions of size 1 will be removed
17    ///   - Dimensions of size greater than 1 will be ignored
18    ///
19    /// # Returns
20    ///
21    /// A `Result<TensorView>` containing:
22    /// - `Ok(TensorView)`: A new view with the specified dimensions removed
23    /// - `Err`: If any dimension index is out of bounds
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// use slsl::Tensor;
29    ///
30    /// // 3D tensor with some dimensions of size 1
31    /// let tensor = Tensor::from_vec(vec![1, 2, 3, 4], [1, 4, 1])?;
32    ///
33    /// // Squeeze specific dimensions
34    /// let squeezed = tensor.squeeze(0)?;  // Remove dimension 0
35    /// assert_eq!(squeezed.dims(), [4, 1]);
36    ///
37    /// let squeezed = tensor.squeeze([0, 2])?;  // Remove dimensions 0 and 2
38    /// assert_eq!(squeezed.dims(), [4]);
39    ///
40    /// // Squeeze all dimensions of size 1
41    /// let squeezed = tensor.squeeze_all()?;
42    /// assert_eq!(squeezed.dims(), [4]);
43    ///
44    /// // 2D tensor with no dimensions of size 1
45    /// let tensor_2d = Tensor::from_vec(vec![1, 2, 3, 4], [2, 2])?;
46    /// let squeezed = tensor_2d.squeeze(0)?;  // No effect
47    /// assert_eq!(squeezed.dims(), [2, 2]);
48    /// # Ok::<(), Box<dyn std::error::Error>>(())
49    /// ```
50    ///
51    /// # Notes
52    ///
53    /// - This operation is memory-efficient as it returns a view rather than copying data
54    /// - Only dimensions of size 1 are removed; larger dimensions are preserved
55    /// - If all dimensions are removed, a scalar tensor (shape `[1]`) is returned
56    /// - The function follows PyTorch's `squeeze` behavior
57    /// - For out-of-bounds dimensions, the function will return an error
58    ///
59    /// # See Also
60    ///
61    /// - [`Self::unsqueeze`]: Add dimensions of size 1
62    /// - [`Self::squeeze_all`]: Remove all dimensions of size 1
63    /// - [`TensorView`]: The view type returned by this function
64    ///
65    /// [PyTorch squeeze]: https://docs.pytorch.org/docs/stable/generated/torch.squeeze.html
66    pub fn squeeze<D: Dims>(&self, dims: D) -> Result<TensorView<'_>> {
67        let dims = dims.to_dims(self.rank())?;
68
69        // Squeeze only the specified dimensions (if they are of size 1)
70        let mut new_shape_array = Shape::empty();
71        let mut new_strides_array = Shape::empty();
72        let mut count = 0usize;
73
74        for i in 0..self.rank() {
75            let should_squeeze = dims.iter().any(|&d_idx| d_idx == i);
76
77            if !should_squeeze || self.shape[i] != 1 {
78                // safe because count < rank <= 8
79                unsafe {
80                    new_shape_array.set_len(count + 1);
81                    new_strides_array.set_len(count + 1);
82                }
83                new_shape_array[count] = self.shape[i];
84                new_strides_array[count] = self.strides[i];
85                count += 1;
86            }
87        }
88
89        if count == 0 {
90            // If all dimensions were 1, create a scalar tensor
91            unsafe {
92                new_shape_array.set_len(1);
93                new_strides_array.set_len(1);
94            }
95            new_shape_array[0] = 1;
96            new_strides_array[0] = 0;
97        }
98
99        // new_shape_array/new_strides_array already populated
100
101        Ok(TensorView {
102            storage: self.storage.as_storage(),
103            ptr: self.ptr,
104            dtype: self.dtype,
105            shape: new_shape_array,
106            strides: new_strides_array,
107            offset_bytes: self.offset_bytes,
108        })
109    }
110
111    /// Returns a new tensor with all dimensions of size one removed.
112    ///
113    /// This is a convenience function that removes all dimensions of size 1 from the tensor.
114    /// It's equivalent to calling `squeeze` with all dimension indices.
115    ///
116    /// # Returns
117    ///
118    /// A `Result<TensorView>` containing:
119    /// - `Ok(TensorView)`: A new view with all size-1 dimensions removed
120    /// - `Err`: If there's an error during the squeeze operation
121    ///
122    /// # Examples
123    ///
124    /// ```
125    /// use slsl::Tensor;
126    ///
127    /// // Tensor with multiple dimensions of size 1
128    /// let tensor = Tensor::from_vec(vec![1, 2, 3, 4], [1, 4, 1, 1])?;
129    ///
130    /// // Remove all dimensions of size 1
131    /// let squeezed = tensor.squeeze_all()?;
132    /// assert_eq!(squeezed.dims(), [4]);
133    ///
134    /// // Tensor with no dimensions of size 1
135    /// let tensor_2d = Tensor::from_vec(vec![1, 2, 3, 4], [2, 2])?;
136    /// let squeezed = tensor_2d.squeeze_all()?;
137    /// assert_eq!(squeezed.dims(), [2, 2]);  // No change
138    /// # Ok::<(), Box<dyn std::error::Error>>(())
139    /// ```
140    ///
141    /// # Notes
142    ///
143    /// - This operation is memory-efficient as it returns a view rather than copying data
144    /// - If all dimensions are of size 1, a scalar tensor (shape `[1]`) is returned
145    /// - This is equivalent to `self.squeeze(0..self.rank())`
146    ///
147    /// # See Also
148    ///
149    /// - [`Self::squeeze`]: Remove specific dimensions of size 1
150    /// - [`Self::unsqueeze`]: Add dimensions of size 1
151    pub fn squeeze_all(&self) -> Result<TensorView<'_>> {
152        // Create a range of all dimensions
153        let all_dims: Vec<usize> = (0..self.rank()).collect();
154        self.squeeze(&*all_dims)
155    }
156
157    /// Returns a new tensor with a dimension of size one inserted at the specified position.
158    ///
159    /// This function creates a new view of the tensor with an additional dimension of size 1
160    /// inserted at the specified position. The returned tensor shares the same underlying data
161    /// as the original tensor, making this a zero-copy operation.
162    ///
163    /// # Arguments
164    ///
165    /// * `dim` - The position at which to insert the new dimension. Must be in the range `[-rank-1, rank]`.
166    ///   - For a 1D tensor `[4]`, valid values are `[-2, 1]`
167    ///   - For a 2D tensor `[2, 2]`, valid values are `[-3, 2]`
168    ///   - Negative indices count from the end: `-1` means the last position, `-2` means the second-to-last, etc.
169    ///
170    /// # Returns
171    ///
172    /// A `Result<TensorView>` containing:
173    /// - `Ok(TensorView)`: A new view with the inserted dimension
174    /// - `Err`: If the dimension index is out of bounds
175    ///
176    /// # Examples
177    ///
178    /// ```
179    /// use slsl::Tensor;
180    ///
181    /// // 1D tensor
182    /// let tensor = Tensor::from_vec(vec![1, 2, 3, 4], [4])?;
183    ///
184    /// // Insert at beginning (dimension 0)
185    /// let unsqueezed = tensor.unsqueeze(0)?;
186    /// assert_eq!(unsqueezed.dims(), [1, 4]);
187    ///
188    /// // Insert at end (dimension 1)
189    /// let unsqueezed = tensor.unsqueeze(1)?;
190    /// assert_eq!(unsqueezed.dims(), [4, 1]);
191    ///
192    /// // Using negative indices
193    /// let unsqueezed = tensor.unsqueeze(-1)?;  // Same as unsqueeze(1)
194    /// assert_eq!(unsqueezed.dims(), [4, 1]);
195    ///
196    /// let unsqueezed = tensor.unsqueeze(-2)?;  // Same as unsqueeze(0)
197    /// assert_eq!(unsqueezed.dims(), [1, 4]);
198    ///
199    /// // 2D tensor
200    /// let tensor_2d = Tensor::from_vec(vec![1, 2, 3, 4], [2, 2])?;
201    /// let unsqueezed = tensor_2d.unsqueeze(1)?;
202    /// assert_eq!(unsqueezed.dims(), [2, 1, 2]);
203    /// # Ok::<(), Box<dyn std::error::Error>>(())
204    /// ```
205    ///
206    /// # Notes
207    ///
208    /// - This operation is memory-efficient as it returns a view rather than copying data
209    /// - The stride for the new dimension is set to 0 since its size is 1
210    /// - The function follows PyTorch's `unsqueeze` behavior for dimension indexing
211    /// - For out-of-bounds dimensions, the function will return an error rather than silently
212    ///   inserting at the end, ensuring user intent is clear
213    ///
214    /// # See Also
215    ///
216    /// - [`Self::squeeze`]: Remove dimensions of size 1
217    /// - [`TensorView`]: The view type returned by this function
218    ///
219    /// [PyTorch unsqueeze]: https://docs.pytorch.org/docs/stable/generated/torch.unsqueeze.html
220    pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<TensorView<'_>> {
221        // For unsqueeze, dimension index can be in range [-rank-1, rank]
222        // This allows inserting at any position including at the end
223        let current_rank = self.rank();
224        let dim_idx = dim.to_dim(current_rank + 1)?;
225
226        // Create new shape/strides using Shape with size 1 inserted at the specified dimension
227        let mut new_shape_array = Shape::empty().with_len(current_rank + 1);
228        let mut new_strides_array = Shape::empty().with_len(current_rank + 1);
229
230        // Insert dimensions before the specified position
231        for i in 0..dim_idx {
232            new_shape_array[i] = self.shape[i];
233            new_strides_array[i] = self.strides[i];
234        }
235
236        // Insert the new dimension of size 1
237        new_shape_array[dim_idx] = 1;
238        new_strides_array[dim_idx] = 0; // Stride 0 for size 1 dimension
239
240        // Insert dimensions after the specified position
241        for i in dim_idx..current_rank {
242            new_shape_array[i + 1] = self.shape[i];
243            new_strides_array[i + 1] = self.strides[i];
244        }
245
246        Ok(TensorView {
247            storage: self.storage.as_storage(),
248            ptr: self.ptr,
249            dtype: self.dtype,
250            shape: new_shape_array,
251            strides: new_strides_array,
252            offset_bytes: self.offset_bytes,
253        })
254    }
255}
256
#[cfg(test)]
mod tests {
    use crate::Tensor;

    /// Squeezing an explicit list of size-1 axes removes exactly those axes.
    #[test]
    fn squeeze_specific_dims() {
        let tensor = Tensor::from_vec(vec![1u8, 2, 3, 4], [1, 4, 1]).unwrap();
        assert_eq!(tensor.squeeze([0, 2]).unwrap().dims(), [4]);
    }

    /// `squeeze_all` strips every size-1 axis, for both a small and a rank-8 shape.
    #[test]
    fn squeeze_all_dims() {
        let small = Tensor::from_vec(vec![1u8, 2, 3, 4], [1, 4, 1]).unwrap();
        assert_eq!(small.squeeze_all().unwrap().dims(), [4]);

        let large = Tensor::rand(0., 10., [1, 2, 1, 3, 4, 1, 2, 1]).unwrap();
        assert_eq!(large.squeeze_all().unwrap().dims(), [2, 3, 4, 2]);
    }

    /// Squeezing an axis whose size is greater than 1 leaves the shape untouched.
    #[test]
    fn squeeze_no_effect() {
        let tensor = Tensor::from_vec(vec![1u8, 2, 3, 4], [2, 2]).unwrap();
        assert_eq!(tensor.squeeze(0).unwrap().dims(), [2, 2]);
    }

    /// `unsqueeze` inserts the unit axis at each tested position, including
    /// negative positions counted from the end.
    #[test]
    fn unsqueeze_all_dims() {
        let tensor = Tensor::rand(0., 10., [2, 3, 4]).unwrap();

        // (insertion position, expected resulting shape)
        let cases = [
            (0, [1, 2, 3, 4]),
            (1, [2, 1, 3, 4]),
            (2, [2, 3, 1, 4]),
            (3, [2, 3, 4, 1]),
            (-1, [2, 3, 4, 1]),
            (-2, [2, 3, 1, 4]),
        ];

        for (position, expected) in cases {
            let view = tensor.unsqueeze(position).unwrap();
            assert_eq!(view.dims(), expected);
        }
    }
}