// torsh_tensor/tensor_utils.rs
//! Tensor Manipulation Utilities
//!
//! This module provides comprehensive tensor manipulation utilities including
//! various squeeze/unsqueeze variants, advanced transpose operations, and
//! dimension manipulation helpers.
//!
//! # Features
//!
//! - **Smart squeezing**: Automatic removal of size-1 dimensions
//! - **Conditional unsqueezing**: Add dimensions based on patterns
//! - **Multi-transpose**: Transpose multiple dimensions at once
//! - **Dimension swapping**: Flexible dimension reordering
//! - **Shape inference**: Automatic shape calculation
15use torsh_core::{
16    dtype::TensorElement,
17    error::{Result, TorshError},
18};
19
20use crate::Tensor;
21
/// Extension trait for advanced tensor manipulation.
///
/// Adds shape-manipulation helpers (squeeze/unsqueeze variants, memory-layout
/// conversions, dimension reordering and broadcasting helpers) on top of the
/// core [`Tensor`] API. Unless noted otherwise, negative dimension indices
/// count from the end, as in the rest of the crate.
pub trait TensorManipulationExt<T: TensorElement> {
    /// Squeeze all dimensions of size 1.
    ///
    /// If *every* dimension is 1, the result is a single-element 1-D tensor
    /// of shape `[1]` (not a 0-D scalar — see the implementation).
    fn squeeze_all(&self) -> Result<Tensor<T>>;

    /// Squeeze the specific dimensions listed in `dims`.
    ///
    /// Errors if any listed dimension is out of range or does not have size 1.
    fn squeeze_dims(&self, dims: &[i32]) -> Result<Tensor<T>>;

    /// Unsqueeze (insert size-1 dimensions) at multiple positions.
    ///
    /// Positions are interpreted in the *original* tensor's indexing, so
    /// `[0, 2]` on a 2-D tensor yields shape `[1, d0, d1, 1]`.
    fn unsqueeze_dims(&self, dims: &[i32]) -> Result<Tensor<T>>;

    /// Add a size-1 batch dimension at the front (`unsqueeze(0)`).
    fn add_batch_dim(&self) -> Result<Tensor<T>>;

    /// Remove the batch dimension (first dimension).
    ///
    /// Errors on a 0-D tensor or when the first dimension's size is not 1.
    fn remove_batch_dim(&self) -> Result<Tensor<T>>;

    /// Ensure tensor has at least `n` dimensions by appending *trailing*
    /// size-1 dimensions; returns a clone unchanged if already wide enough.
    fn atleast_nd(&self, n: usize) -> Result<Tensor<T>>;

    /// Transpose to channel-last format (4-D: NCHW -> NHWC, 3-D: CHW -> HWC).
    fn to_channel_last(&self) -> Result<Tensor<T>>;

    /// Transpose to channel-first format (4-D: NHWC -> NCHW, 3-D: HWC -> CHW).
    fn to_channel_first(&self) -> Result<Tensor<T>>;

    /// Swap two dimensions (thin wrapper over `transpose`).
    fn swap_dims(&self, dim0: i32, dim1: i32) -> Result<Tensor<T>>;

    /// Move a dimension from position `src` to position `dst`, shifting the
    /// dimensions in between (like `torch.movedim`).
    fn move_dim(&self, src: i32, dst: i32) -> Result<Tensor<T>>;

    /// Expand singleton dimensions to match `target_shape` (broadcasting
    /// rules: trailing-aligned; each dim must match or currently be 1).
    fn expand_to(&self, target_shape: &[usize]) -> Result<Tensor<T>>;

    /// Insert a new dimension at `dim` and repeat the tensor `repeats` times
    /// along it, so the result has one extra dimension of size `repeats`.
    fn repeat_along(&self, dim: i32, repeats: usize) -> Result<Tensor<T>>;
}
60
61impl<T: TensorElement + Copy> TensorManipulationExt<T> for Tensor<T> {
62    fn squeeze_all(&self) -> Result<Tensor<T>> {
63        let shape_binding = self.shape();
64        let shape = shape_binding.dims();
65        let new_shape: Vec<usize> = shape.iter().filter(|&&s| s != 1).copied().collect();
66
67        if new_shape.is_empty() {
68            // All dimensions were 1, create scalar
69            self.reshape(&[1])
70        } else {
71            let new_shape_i32: Vec<i32> = new_shape.iter().map(|&s| s as i32).collect();
72            self.reshape(&new_shape_i32)
73        }
74    }
75
76    fn squeeze_dims(&self, dims: &[i32]) -> Result<Tensor<T>> {
77        let shape_binding = self.shape();
78        let shape = shape_binding.dims();
79        let ndim = shape.len() as i32;
80
81        // Normalize dimensions
82        let normalized_dims: Result<Vec<usize>> = dims
83            .iter()
84            .map(|&d| {
85                let normalized = if d < 0 { ndim + d } else { d };
86                if normalized < 0 || normalized >= ndim {
87                    Err(TorshError::InvalidArgument(format!(
88                        "Dimension {} out of range for tensor with {} dimensions",
89                        d, ndim
90                    )))
91                } else {
92                    Ok(normalized as usize)
93                }
94            })
95            .collect();
96
97        let normalized_dims = normalized_dims?;
98
99        // Check that specified dimensions are size 1
100        for &dim in &normalized_dims {
101            if shape[dim] != 1 {
102                return Err(TorshError::InvalidArgument(format!(
103                    "Cannot squeeze dimension {} of size {}",
104                    dim, shape[dim]
105                )));
106            }
107        }
108
109        // Build new shape without squeezed dimensions
110        let new_shape: Vec<usize> = shape
111            .iter()
112            .enumerate()
113            .filter(|(i, _)| !normalized_dims.contains(i))
114            .map(|(_, &s)| s)
115            .collect();
116
117        if new_shape.is_empty() {
118            self.reshape(&[1])
119        } else {
120            let new_shape_i32: Vec<i32> = new_shape.iter().map(|&s| s as i32).collect();
121            self.reshape(&new_shape_i32)
122        }
123    }
124
125    fn unsqueeze_dims(&self, dims: &[i32]) -> Result<Tensor<T>> {
126        let mut result = self.clone();
127
128        // Sort dimensions to handle in ascending order
129        let mut sorted_dims: Vec<i32> = dims.to_vec();
130        sorted_dims.sort_unstable();
131
132        // Process in order, adjusting subsequent dims for already-inserted dimensions
133        for (i, &dim) in sorted_dims.iter().enumerate() {
134            // Adjust for previously inserted dimensions
135            let adjusted_dim = dim + i as i32;
136            result = result.unsqueeze(adjusted_dim)?;
137        }
138
139        Ok(result)
140    }
141
142    fn add_batch_dim(&self) -> Result<Tensor<T>> {
143        self.unsqueeze(0)
144    }
145
146    fn remove_batch_dim(&self) -> Result<Tensor<T>> {
147        let shape_binding = self.shape();
148        let shape = shape_binding.dims();
149        if shape.is_empty() {
150            return Err(TorshError::InvalidArgument(
151                "Cannot remove batch dim from scalar tensor".to_string(),
152            ));
153        }
154
155        if shape[0] != 1 {
156            return Err(TorshError::InvalidArgument(format!(
157                "Batch dimension has size {}, expected 1",
158                shape[0]
159            )));
160        }
161
162        self.squeeze(0)
163    }
164
165    fn atleast_nd(&self, n: usize) -> Result<Tensor<T>> {
166        let shape_binding = self.shape();
167        let shape = shape_binding.dims();
168        let current_ndim = shape.len();
169
170        if current_ndim >= n {
171            return Ok(self.clone());
172        }
173
174        let mut new_shape = shape.to_vec();
175        for _ in current_ndim..n {
176            new_shape.push(1);
177        }
178
179        let new_shape_i32: Vec<i32> = new_shape.iter().map(|&s| s as i32).collect();
180        self.reshape(&new_shape_i32)
181    }
182
183    fn to_channel_last(&self) -> Result<Tensor<T>> {
184        let shape_binding = self.shape();
185        let shape = shape_binding.dims();
186
187        match shape.len() {
188            4 => {
189                // NCHW -> NHWC
190                self.permute(&[0, 2, 3, 1])
191            }
192            3 => {
193                // CHW -> HWC
194                self.permute(&[1, 2, 0])
195            }
196            _ => Err(TorshError::InvalidArgument(
197                "to_channel_last expects 3D or 4D tensor".to_string(),
198            )),
199        }
200    }
201
202    fn to_channel_first(&self) -> Result<Tensor<T>> {
203        let shape_binding = self.shape();
204        let shape = shape_binding.dims();
205
206        match shape.len() {
207            4 => {
208                // NHWC -> NCHW
209                self.permute(&[0, 3, 1, 2])
210            }
211            3 => {
212                // HWC -> CHW
213                self.permute(&[2, 0, 1])
214            }
215            _ => Err(TorshError::InvalidArgument(
216                "to_channel_first expects 3D or 4D tensor".to_string(),
217            )),
218        }
219    }
220
221    fn swap_dims(&self, dim0: i32, dim1: i32) -> Result<Tensor<T>> {
222        self.transpose(dim0, dim1)
223    }
224
225    fn move_dim(&self, src: i32, dst: i32) -> Result<Tensor<T>> {
226        let ndim = self.shape().dims().len() as i32;
227
228        // Normalize dimensions
229        let src = if src < 0 { ndim + src } else { src };
230        let dst = if dst < 0 { ndim + dst } else { dst };
231
232        if src < 0 || src >= ndim || dst < 0 || dst >= ndim {
233            return Err(TorshError::InvalidArgument(
234                "Dimension out of range".to_string(),
235            ));
236        }
237
238        if src == dst {
239            return Ok(self.clone());
240        }
241
242        // Build permutation to move dimension
243        let mut perm: Vec<i32> = (0..ndim).collect();
244        let src_dim = perm.remove(src as usize);
245
246        perm.insert(dst as usize, src_dim);
247
248        self.permute(&perm)
249    }
250
251    fn expand_to(&self, target_shape: &[usize]) -> Result<Tensor<T>> {
252        let shape_binding = self.shape();
253        let current_shape = shape_binding.dims();
254
255        if current_shape.len() > target_shape.len() {
256            return Err(TorshError::InvalidArgument(
257                "Cannot expand to shape with fewer dimensions".to_string(),
258            ));
259        }
260
261        // Check compatibility
262        for (i, &current_size) in current_shape.iter().rev().enumerate() {
263            let target_idx = target_shape.len() - 1 - i;
264            let target_size = target_shape[target_idx];
265
266            if current_size != 1 && current_size != target_size {
267                return Err(TorshError::InvalidArgument(format!(
268                    "Cannot expand dimension {} from {} to {}",
269                    target_idx, current_size, target_size
270                )));
271            }
272        }
273
274        self.expand(target_shape)
275    }
276
277    fn repeat_along(&self, dim: i32, repeats: usize) -> Result<Tensor<T>> {
278        // First unsqueeze at dim, then repeat
279        let unsqueezed = self.unsqueeze(dim)?;
280        let shape_binding = unsqueezed.shape();
281        let shape = shape_binding.dims();
282
283        let mut repeat_shape = vec![1; shape.len()];
284        let normalized_dim = if dim < 0 {
285            (shape.len() as i32 + dim) as usize
286        } else {
287            dim as usize
288        };
289
290        repeat_shape[normalized_dim] = repeats;
291
292        unsqueezed.repeat(&repeat_shape)
293    }
294}
295
296/// Helper functions for shape manipulation
297pub mod shape_utils {
298    use super::*;
299
300    /// Calculate the number of elements in a shape
301    pub fn numel(shape: &[usize]) -> usize {
302        shape.iter().product()
303    }
304
305    /// Check if two shapes are broadcast-compatible
306    pub fn are_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
307        let len1 = shape1.len();
308        let len2 = shape2.len();
309        let max_len = len1.max(len2);
310
311        for i in 0..max_len {
312            let dim1 = if i < len1 { shape1[len1 - 1 - i] } else { 1 };
313
314            let dim2 = if i < len2 { shape2[len2 - 1 - i] } else { 1 };
315
316            if dim1 != 1 && dim2 != 1 && dim1 != dim2 {
317                return false;
318            }
319        }
320
321        true
322    }
323
324    /// Calculate broadcast shape
325    pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
326        if !are_broadcastable(shape1, shape2) {
327            return None;
328        }
329
330        let len1 = shape1.len();
331        let len2 = shape2.len();
332        let max_len = len1.max(len2);
333
334        let mut result = Vec::with_capacity(max_len);
335
336        for i in 0..max_len {
337            let dim1 = if i < len1 { shape1[len1 - 1 - i] } else { 1 };
338
339            let dim2 = if i < len2 { shape2[len2 - 1 - i] } else { 1 };
340
341            result.push(dim1.max(dim2));
342        }
343
344        result.reverse();
345        Some(result)
346    }
347
348    /// Infer shape with -1 (unknown dimension)
349    pub fn infer_shape(shape: &[i32], total_elements: usize) -> Result<Vec<usize>> {
350        let mut result = Vec::new();
351        let mut unknown_idx = None;
352        let mut known_product = 1usize;
353
354        for (i, &dim) in shape.iter().enumerate() {
355            if dim == -1 {
356                if unknown_idx.is_some() {
357                    return Err(TorshError::InvalidArgument(
358                        "Only one dimension can be inferred".to_string(),
359                    ));
360                }
361                unknown_idx = Some(i);
362                result.push(0); // Placeholder
363            } else if dim < 0 {
364                return Err(TorshError::InvalidArgument(format!(
365                    "Invalid dimension size: {}",
366                    dim
367                )));
368            } else {
369                result.push(dim as usize);
370                known_product *= dim as usize;
371            }
372        }
373
374        if let Some(idx) = unknown_idx {
375            if known_product == 0 {
376                return Err(TorshError::InvalidArgument(
377                    "Cannot infer dimension with zero-sized dimensions".to_string(),
378                ));
379            }
380
381            if total_elements % known_product != 0 {
382                return Err(TorshError::InvalidArgument(
383                    "Cannot infer dimension: size is not divisible".to_string(),
384                ));
385            }
386
387            result[idx] = total_elements / known_product;
388        }
389
390        Ok(result)
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::*;

    #[test]
    fn test_squeeze_all() {
        // Size-1 dims at positions 0, 2 and 4 should all disappear.
        let t = zeros::<f32>(&[1, 3, 1, 4, 1]).expect("zeros creation should succeed");
        let out = t.squeeze_all().expect("squeeze_all should succeed");
        assert_eq!(out.shape().dims(), &[3, 4]);
    }

    #[test]
    fn test_squeeze_dims() {
        // Only the explicitly listed singleton dims (0 and 2) are removed.
        let t = zeros::<f32>(&[1, 3, 1, 4]).expect("zeros creation should succeed");
        let out = t.squeeze_dims(&[0, 2]).expect("squeeze_dims should succeed");
        assert_eq!(out.shape().dims(), &[3, 4]);
    }

    #[test]
    fn test_unsqueeze_dims() {
        // Dims are positions in the ORIGINAL tensor's indexing:
        // insert at 0 -> [1, 3, 4]; insert at 2 (shifted to 3) -> [1, 3, 4, 1].
        let t = zeros::<f32>(&[3, 4]).expect("zeros creation should succeed");
        let out = t.unsqueeze_dims(&[0, 2]).expect("unsqueeze_dims should succeed");
        assert_eq!(out.shape().dims(), &[1, 3, 4, 1]);
    }

    #[test]
    fn test_add_remove_batch_dim() {
        // add_batch_dim / remove_batch_dim round-trip back to the original shape.
        let t = zeros::<f32>(&[3, 4]).expect("zeros creation should succeed");

        let batched = t.add_batch_dim().expect("add_batch_dim should succeed");
        assert_eq!(batched.shape().dims(), &[1, 3, 4]);

        let unbatched = batched
            .remove_batch_dim()
            .expect("remove_batch_dim should succeed");
        assert_eq!(unbatched.shape().dims(), &[3, 4]);
    }

    #[test]
    fn test_atleast_nd() {
        // Padding is appended as TRAILING singleton dimensions.
        let t = zeros::<f32>(&[3, 4]).expect("zeros creation should succeed");
        let out = t.atleast_nd(4).expect("atleast_nd should succeed");
        assert_eq!(out.shape().dims(), &[3, 4, 1, 1]);
    }

    #[test]
    fn test_channel_conversions() {
        // Round-trip NCHW -> NHWC -> NCHW.
        let nchw = zeros::<f32>(&[2, 3, 4, 5]).expect("zeros creation should succeed");

        let nhwc = nchw
            .to_channel_last()
            .expect("channel conversion should succeed");
        assert_eq!(nhwc.shape().dims(), &[2, 4, 5, 3]);

        let back = nhwc
            .to_channel_first()
            .expect("channel conversion should succeed");
        assert_eq!(back.shape().dims(), &[2, 3, 4, 5]);
    }

    #[test]
    fn test_move_dim() {
        // Moving dim 1 (size 3) to position 3 shifts dims 2 and 3 left.
        let t = zeros::<f32>(&[2, 3, 4, 5]).expect("zeros creation should succeed");
        let out = t.move_dim(1, 3).expect("move_dim should succeed");
        assert_eq!(out.shape().dims(), &[2, 4, 5, 3]);
    }

    #[test]
    fn test_shape_utils_broadcastable() {
        use shape_utils::*;

        assert!(are_broadcastable(&[3, 1, 4], &[1, 5, 4]));
        assert!(are_broadcastable(&[3, 4], &[3, 4]));
        assert!(are_broadcastable(&[1], &[3, 4]));
        assert!(!are_broadcastable(&[3, 4], &[2, 4]));
    }

    #[test]
    fn test_shape_utils_broadcast_shape() {
        use shape_utils::*;

        assert_eq!(broadcast_shape(&[3, 1, 4], &[1, 5, 4]), Some(vec![3, 5, 4]));
        assert_eq!(broadcast_shape(&[3, 4], &[2, 4]), None);
    }

    #[test]
    fn test_shape_utils_infer_shape() {
        use shape_utils::*;

        assert_eq!(
            infer_shape(&[2, -1, 3], 24).expect("shape inference should succeed"),
            vec![2, 4, 3]
        );
        assert_eq!(
            infer_shape(&[3, 4], 12).expect("shape inference should succeed"),
            vec![3, 4]
        );
    }

    #[test]
    fn test_squeeze_dims_invalid() {
        // Dimension 0 has size 3, so squeezing it must fail.
        let t = zeros::<f32>(&[3, 4]).expect("zeros creation should succeed");
        assert!(t.squeeze_dims(&[0]).is_err());
    }

    #[test]
    fn test_remove_batch_dim_invalid() {
        // Leading dimension has size 3, not 1, so removal must fail.
        let t = zeros::<f32>(&[3, 4]).expect("zeros creation should succeed");
        assert!(t.remove_batch_dim().is_err());
    }
}