slsl/ops/squeeze.rs
1use anyhow::Result;
2
3use crate::{Dim, Dims, Shape, StorageTrait, TensorBase, TensorView};
4
5impl<S: StorageTrait> TensorBase<S> {
6 /// Returns a new tensor with dimensions of size one removed from the specified positions.
7 ///
8 /// This function creates a new view of the tensor with dimensions of size 1 removed
9 /// from the specified positions. The returned tensor shares the same underlying data
10 /// as the original tensor, making this a zero-copy operation.
11 ///
12 /// # Arguments
13 ///
14 /// * `dims` - The dimensions to squeeze. Must be valid dimension indices for this tensor.
15 /// - Can be a single dimension index, a slice of indices, or a range
16 /// - Only dimensions of size 1 will be removed
17 /// - Dimensions of size greater than 1 will be ignored
18 ///
19 /// # Returns
20 ///
21 /// A `Result<TensorView>` containing:
22 /// - `Ok(TensorView)`: A new view with the specified dimensions removed
23 /// - `Err`: If any dimension index is out of bounds
24 ///
25 /// # Examples
26 ///
27 /// ```
28 /// use slsl::Tensor;
29 ///
30 /// // 3D tensor with some dimensions of size 1
31 /// let tensor = Tensor::from_vec(vec![1, 2, 3, 4], [1, 4, 1])?;
32 ///
33 /// // Squeeze specific dimensions
34 /// let squeezed = tensor.squeeze(0)?; // Remove dimension 0
35 /// assert_eq!(squeezed.dims(), [4, 1]);
36 ///
37 /// let squeezed = tensor.squeeze([0, 2])?; // Remove dimensions 0 and 2
38 /// assert_eq!(squeezed.dims(), [4]);
39 ///
40 /// // Squeeze all dimensions of size 1
41 /// let squeezed = tensor.squeeze_all()?;
42 /// assert_eq!(squeezed.dims(), [4]);
43 ///
44 /// // 2D tensor with no dimensions of size 1
45 /// let tensor_2d = Tensor::from_vec(vec![1, 2, 3, 4], [2, 2])?;
46 /// let squeezed = tensor_2d.squeeze(0)?; // No effect
47 /// assert_eq!(squeezed.dims(), [2, 2]);
48 /// # Ok::<(), Box<dyn std::error::Error>>(())
49 /// ```
50 ///
51 /// # Notes
52 ///
53 /// - This operation is memory-efficient as it returns a view rather than copying data
54 /// - Only dimensions of size 1 are removed; larger dimensions are preserved
55 /// - If all dimensions are removed, a scalar tensor (shape `[1]`) is returned
56 /// - The function follows PyTorch's `squeeze` behavior
57 /// - For out-of-bounds dimensions, the function will return an error
58 ///
59 /// # See Also
60 ///
61 /// - [`Self::unsqueeze`]: Add dimensions of size 1
62 /// - [`Self::squeeze_all`]: Remove all dimensions of size 1
63 /// - [`TensorView`]: The view type returned by this function
64 ///
65 /// [PyTorch squeeze]: https://docs.pytorch.org/docs/stable/generated/torch.squeeze.html
66 pub fn squeeze<D: Dims>(&self, dims: D) -> Result<TensorView<'_>> {
67 let dims = dims.to_dims(self.rank())?;
68
69 // Squeeze only the specified dimensions (if they are of size 1)
70 let mut new_shape_array = Shape::empty();
71 let mut new_strides_array = Shape::empty();
72 let mut count = 0usize;
73
74 for i in 0..self.rank() {
75 let should_squeeze = dims.iter().any(|&d_idx| d_idx == i);
76
77 if !should_squeeze || self.shape[i] != 1 {
78 // safe because count < rank <= 8
79 unsafe {
80 new_shape_array.set_len(count + 1);
81 new_strides_array.set_len(count + 1);
82 }
83 new_shape_array[count] = self.shape[i];
84 new_strides_array[count] = self.strides[i];
85 count += 1;
86 }
87 }
88
89 if count == 0 {
90 // If all dimensions were 1, create a scalar tensor
91 unsafe {
92 new_shape_array.set_len(1);
93 new_strides_array.set_len(1);
94 }
95 new_shape_array[0] = 1;
96 new_strides_array[0] = 0;
97 }
98
99 // new_shape_array/new_strides_array already populated
100
101 Ok(TensorView {
102 storage: self.storage.as_storage(),
103 ptr: self.ptr,
104 dtype: self.dtype,
105 shape: new_shape_array,
106 strides: new_strides_array,
107 offset_bytes: self.offset_bytes,
108 })
109 }
110
111 /// Returns a new tensor with all dimensions of size one removed.
112 ///
113 /// This is a convenience function that removes all dimensions of size 1 from the tensor.
114 /// It's equivalent to calling `squeeze` with all dimension indices.
115 ///
116 /// # Returns
117 ///
118 /// A `Result<TensorView>` containing:
119 /// - `Ok(TensorView)`: A new view with all size-1 dimensions removed
120 /// - `Err`: If there's an error during the squeeze operation
121 ///
122 /// # Examples
123 ///
124 /// ```
125 /// use slsl::Tensor;
126 ///
127 /// // Tensor with multiple dimensions of size 1
128 /// let tensor = Tensor::from_vec(vec![1, 2, 3, 4], [1, 4, 1, 1])?;
129 ///
130 /// // Remove all dimensions of size 1
131 /// let squeezed = tensor.squeeze_all()?;
132 /// assert_eq!(squeezed.dims(), [4]);
133 ///
134 /// // Tensor with no dimensions of size 1
135 /// let tensor_2d = Tensor::from_vec(vec![1, 2, 3, 4], [2, 2])?;
136 /// let squeezed = tensor_2d.squeeze_all()?;
137 /// assert_eq!(squeezed.dims(), [2, 2]); // No change
138 /// # Ok::<(), Box<dyn std::error::Error>>(())
139 /// ```
140 ///
141 /// # Notes
142 ///
143 /// - This operation is memory-efficient as it returns a view rather than copying data
144 /// - If all dimensions are of size 1, a scalar tensor (shape `[1]`) is returned
145 /// - This is equivalent to `self.squeeze(0..self.rank())`
146 ///
147 /// # See Also
148 ///
149 /// - [`Self::squeeze`]: Remove specific dimensions of size 1
150 /// - [`Self::unsqueeze`]: Add dimensions of size 1
151 pub fn squeeze_all(&self) -> Result<TensorView<'_>> {
152 // Create a range of all dimensions
153 let all_dims: Vec<usize> = (0..self.rank()).collect();
154 self.squeeze(&*all_dims)
155 }
156
157 /// Returns a new tensor with a dimension of size one inserted at the specified position.
158 ///
159 /// This function creates a new view of the tensor with an additional dimension of size 1
160 /// inserted at the specified position. The returned tensor shares the same underlying data
161 /// as the original tensor, making this a zero-copy operation.
162 ///
163 /// # Arguments
164 ///
165 /// * `dim` - The position at which to insert the new dimension. Must be in the range `[-rank-1, rank]`.
166 /// - For a 1D tensor `[4]`, valid values are `[-2, 1]`
167 /// - For a 2D tensor `[2, 2]`, valid values are `[-3, 2]`
168 /// - Negative indices count from the end: `-1` means the last position, `-2` means the second-to-last, etc.
169 ///
170 /// # Returns
171 ///
172 /// A `Result<TensorView>` containing:
173 /// - `Ok(TensorView)`: A new view with the inserted dimension
174 /// - `Err`: If the dimension index is out of bounds
175 ///
176 /// # Examples
177 ///
178 /// ```
179 /// use slsl::Tensor;
180 ///
181 /// // 1D tensor
182 /// let tensor = Tensor::from_vec(vec![1, 2, 3, 4], [4])?;
183 ///
184 /// // Insert at beginning (dimension 0)
185 /// let unsqueezed = tensor.unsqueeze(0)?;
186 /// assert_eq!(unsqueezed.dims(), [1, 4]);
187 ///
188 /// // Insert at end (dimension 1)
189 /// let unsqueezed = tensor.unsqueeze(1)?;
190 /// assert_eq!(unsqueezed.dims(), [4, 1]);
191 ///
192 /// // Using negative indices
193 /// let unsqueezed = tensor.unsqueeze(-1)?; // Same as unsqueeze(1)
194 /// assert_eq!(unsqueezed.dims(), [4, 1]);
195 ///
196 /// let unsqueezed = tensor.unsqueeze(-2)?; // Same as unsqueeze(0)
197 /// assert_eq!(unsqueezed.dims(), [1, 4]);
198 ///
199 /// // 2D tensor
200 /// let tensor_2d = Tensor::from_vec(vec![1, 2, 3, 4], [2, 2])?;
201 /// let unsqueezed = tensor_2d.unsqueeze(1)?;
202 /// assert_eq!(unsqueezed.dims(), [2, 1, 2]);
203 /// # Ok::<(), Box<dyn std::error::Error>>(())
204 /// ```
205 ///
206 /// # Notes
207 ///
208 /// - This operation is memory-efficient as it returns a view rather than copying data
209 /// - The stride for the new dimension is set to 0 since its size is 1
210 /// - The function follows PyTorch's `unsqueeze` behavior for dimension indexing
211 /// - For out-of-bounds dimensions, the function will return an error rather than silently
212 /// inserting at the end, ensuring user intent is clear
213 ///
214 /// # See Also
215 ///
216 /// - [`Self::squeeze`]: Remove dimensions of size 1
217 /// - [`TensorView`]: The view type returned by this function
218 ///
219 /// [PyTorch unsqueeze]: https://docs.pytorch.org/docs/stable/generated/torch.unsqueeze.html
220 pub fn unsqueeze<D: Dim>(&self, dim: D) -> Result<TensorView<'_>> {
221 // For unsqueeze, dimension index can be in range [-rank-1, rank]
222 // This allows inserting at any position including at the end
223 let current_rank = self.rank();
224 let dim_idx = dim.to_dim(current_rank + 1)?;
225
226 // Create new shape/strides using Shape with size 1 inserted at the specified dimension
227 let mut new_shape_array = Shape::empty().with_len(current_rank + 1);
228 let mut new_strides_array = Shape::empty().with_len(current_rank + 1);
229
230 // Insert dimensions before the specified position
231 for i in 0..dim_idx {
232 new_shape_array[i] = self.shape[i];
233 new_strides_array[i] = self.strides[i];
234 }
235
236 // Insert the new dimension of size 1
237 new_shape_array[dim_idx] = 1;
238 new_strides_array[dim_idx] = 0; // Stride 0 for size 1 dimension
239
240 // Insert dimensions after the specified position
241 for i in dim_idx..current_rank {
242 new_shape_array[i + 1] = self.shape[i];
243 new_strides_array[i + 1] = self.strides[i];
244 }
245
246 Ok(TensorView {
247 storage: self.storage.as_storage(),
248 ptr: self.ptr,
249 dtype: self.dtype,
250 shape: new_shape_array,
251 strides: new_strides_array,
252 offset_bytes: self.offset_bytes,
253 })
254 }
255}
256
#[cfg(test)]
mod tests {
    use crate::Tensor;

    #[test]
    fn squeeze_specific_dims() {
        let t = Tensor::from_vec(vec![1u8, 2, 3, 4], [1, 4, 1]).unwrap();
        let v = t.squeeze([0, 2]).unwrap();
        assert_eq!(v.dims(), [4]);
    }

    #[test]
    fn squeeze_all_dims() {
        let t = Tensor::from_vec(vec![1u8, 2, 3, 4], [1, 4, 1]).unwrap();
        let v = t.squeeze_all().unwrap();
        assert_eq!(v.dims(), [4]);

        let t = Tensor::rand(0., 10., [1, 2, 1, 3, 4, 1, 2, 1]).unwrap();
        let v = t.squeeze_all().unwrap();
        assert_eq!(v.dims(), [2, 3, 4, 2]);
    }

    #[test]
    fn squeeze_no_effect() {
        let t = Tensor::from_vec(vec![1u8, 2, 3, 4], [2, 2]).unwrap();
        let v = t.squeeze(0).unwrap();
        assert_eq!(v.dims(), [2, 2]);
    }

    #[test]
    fn squeeze_keeps_listed_dims_larger_than_one() {
        // Dimension 1 has size 4, so listing it must not remove it.
        let t = Tensor::from_vec(vec![1u8, 2, 3, 4], [1, 4, 1]).unwrap();
        let v = t.squeeze([1, 2]).unwrap();
        assert_eq!(v.dims(), [1, 4]);
    }

    #[test]
    fn squeeze_everything_yields_shape_one() {
        // Fully squeezed tensors are represented as shape `[1]`, not rank 0.
        let t = Tensor::from_vec(vec![7u8], [1, 1]).unwrap();
        assert_eq!(t.squeeze([0, 1]).unwrap().dims(), [1]);

        let t = Tensor::from_vec(vec![7u8], [1, 1, 1]).unwrap();
        assert_eq!(t.squeeze_all().unwrap().dims(), [1]);
    }

    #[test]
    fn squeeze_out_of_bounds_errors() {
        let t = Tensor::from_vec(vec![1u8, 2, 3, 4], [1, 4, 1]).unwrap();
        assert!(t.squeeze(3).is_err());
    }

    #[test]
    fn unsqueeze_all_dims() {
        let t = Tensor::rand(0., 10., [2, 3, 4]).unwrap();
        let v = t.unsqueeze(0).unwrap();
        assert_eq!(v.dims(), [1, 2, 3, 4]);

        let v = t.unsqueeze(1).unwrap();
        assert_eq!(v.dims(), [2, 1, 3, 4]);

        let v = t.unsqueeze(2).unwrap();
        assert_eq!(v.dims(), [2, 3, 1, 4]);

        let v = t.unsqueeze(3).unwrap();
        assert_eq!(v.dims(), [2, 3, 4, 1]);

        let v = t.unsqueeze(-1).unwrap();
        assert_eq!(v.dims(), [2, 3, 4, 1]);

        let v = t.unsqueeze(-2).unwrap();
        assert_eq!(v.dims(), [2, 3, 1, 4]);
    }

    #[test]
    fn unsqueeze_out_of_bounds_errors() {
        // For rank 3 the valid range is [-4, 3].
        let t = Tensor::rand(0., 10., [2, 3, 4]).unwrap();
        assert!(t.unsqueeze(4).is_err());
        assert!(t.unsqueeze(-5).is_err());
    }
}