
torsh_tensor/
shape_ops.rs

//! Shape and view operations for tensors
//!
//! This module provides comprehensive tensor shape manipulation and view operations
//! including reshaping, transposing, slicing, squeezing, unsqueezing, and permuting.
//!
//! # Features
//!
//! - **Zero-copy views**: Efficient view operations that share underlying data
//! - **Safe reshaping**: Comprehensive validation and overflow checking
//! - **Dimension manipulation**: Squeeze, unsqueeze, transpose, and permute operations
//! - **Slicing operations**: Flexible tensor slicing with stride computation
//! - **Broadcasting support**: Expand operations for broadcasting compatibility
//! - **Contiguity checking**: Efficient memory layout validation

use std::sync::{Arc, RwLock};
use torsh_core::{
    dtype::TensorElement,
    error::{Result, TorshError},
    shape::Shape,
};

use crate::core_ops::{Operation, Tensor};

impl<T: TensorElement + Copy> Tensor<T> {
    /// Get size of a specific dimension
    pub fn size(&self, dim: i32) -> Result<usize> {
        self.shape().size(dim)
    }

    /// Reshapes the tensor to a new shape.
    ///
    /// This mirrors PyTorch's `view()`: the total number of elements must remain
    /// the same, and `-1` may be used for one dimension to have it inferred
    /// automatically. Note that the current implementation materializes a new
    /// tensor from the data; see [`Self::view_as`] for a zero-copy view of
    /// contiguous tensors.
    ///
    /// # Arguments
    ///
    /// * `shape` - The new shape as a slice of dimensions. Use `-1` to infer one dimension.
    ///
    /// # Returns
    ///
    /// A reshaped tensor, or an error if the reshape is invalid.
    ///
    /// # Examples
    ///
    /// ```
    /// use torsh_tensor::creation::zeros;
    ///
    /// // Reshape a 1D tensor to 2D
    /// let t = zeros::<f32>(&[6]).expect("tensor creation should succeed");
    /// let reshaped = t.view(&[2, 3]).expect("view should succeed");
    /// assert_eq!(reshaped.shape().dims(), &[2, 3]);
    ///
    /// // Use -1 to infer a dimension
    /// let t2 = zeros::<f32>(&[12]).expect("tensor creation should succeed");
    /// let auto = t2.view(&[-1, 4]).expect("view should succeed");  // Infers 3 for first dimension
    /// assert_eq!(auto.shape().dims(), &[3, 4]);
    ///
    /// // Flatten to 1D
    /// let matrix = zeros::<f32>(&[3, 4, 5]).expect("tensor creation should succeed");
    /// let flat = matrix.view(&[-1]).expect("view should succeed");
    /// assert_eq!(flat.shape().dims(), &[60]);
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - More than one dimension is `-1`
    /// - The total number of elements doesn't match
    /// - Any dimension would overflow
    ///
    /// # See Also
    ///
    /// * [`Self::reshape`] - Alias for `view()`
    /// * [`Self::view_as`] - Zero-copy view for compatible shapes
    /// * [`Self::contiguous`] - Make tensor contiguous in memory
    pub fn view(&self, shape: &[i32]) -> Result<Self> {
        // Validate that there's at most one -1 in the shape
        let infer_count = shape.iter().filter(|&&x| x == -1).count();
        if infer_count > 1 {
            return Err(TorshError::InvalidShape(
                "Only one dimension can be inferred (only one -1 allowed)".to_string(),
            ));
        }

        let new_shape: Result<Vec<usize>> = shape
            .iter()
            .map(|&d| {
                if d == -1 {
                    // Infer dimension - first validate all other dimensions are valid
                    let known_dims: Result<Vec<usize>> = shape
                        .iter()
                        .filter(|&&x| x != -1)
                        .map(|&x| {
                            if x < 0 {
                                Err(TorshError::InvalidShape(format!(
                                    "Invalid dimension size: {x} (negative dimensions not allowed except -1)"
                                )))
                            } else {
                                Ok(x as usize)
                            }
                        })
                        .collect();

                    let known_dims = known_dims?;

                    // Check for overflow in product calculation
                    let known_product = known_dims.iter().try_fold(1usize, |acc, &dim| {
                        acc.checked_mul(dim).ok_or_else(|| {
                            TorshError::InvalidShape(
                                "Shape dimensions too large (would overflow)".to_string()
                            )
                        })
                    })?;

                    if known_product == 0 {
                        return Err(TorshError::InvalidShape(
                            "Cannot infer dimension with zero-sized dimensions".to_string(),
                        ));
                    }

                    let total = self.numel();
                    if total % known_product != 0 {
                        return Err(TorshError::InvalidShape(
                            "Cannot infer dimension: size is not divisible".to_string(),
                        ));
                    }

                    Ok(total / known_product)
                } else if d < 0 {
                    Err(TorshError::InvalidShape(format!(
                        "Invalid dimension size: {d}"
                    )))
                } else {
                    Ok(d as usize)
                }
            })
            .collect();

        let new_shape = new_shape?;

        // Check for overflow in total elements calculation
        let new_numel = new_shape.iter().try_fold(1usize, |acc, &dim| {
            acc.checked_mul(dim).ok_or_else(|| {
                TorshError::InvalidShape(
                    "Reshaped tensor would be too large (would overflow)".to_string(),
                )
            })
        })?;

        if new_numel != self.numel() {
            return Err(TorshError::InvalidShape(format!(
                "Shape {:?} is invalid for tensor of size {}",
                new_shape,
                self.numel()
            )));
        }

        // Create a new tensor with the same data but different shape
        let data = self.to_vec()?;
        Self::from_data(data, new_shape, self.device)
    }

    /// Create an efficient view with a different shape (shares data, no copying).
    /// This is the zero-copy counterpart of `view()`; it requires the tensor to be contiguous.
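    ///
    /// A minimal sketch (marked `ignore`; it assumes the `Tensor::from_data`
    /// and `DeviceType::Cpu` setup used by the tests in this module):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32; 6], vec![2, 3], DeviceType::Cpu)?;
    /// let v = t.view_as(&[3, 2])?; // shares the same storage, no copy
    /// assert_eq!(v.shape().dims(), &[3, 2]);
    /// ```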
    pub fn view_as(&self, shape: &[usize]) -> Result<Self> {
        // Validate that the total number of elements is the same
        let new_numel = shape.iter().product::<usize>();
        if new_numel != self.numel() {
            return Err(TorshError::InvalidShape(format!(
                "Shape {:?} is invalid for tensor of size {}",
                shape,
                self.numel()
            )));
        }

        // Only create efficient views for contiguous tensors or existing views
        // that are still relatively simple
        if !self.is_contiguous() {
            return Err(TorshError::InvalidShape(
                "Cannot create efficient view of non-contiguous tensor".to_string(),
            ));
        }

        // Create new tensor sharing the same storage
        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(shape.to_vec()),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)), // Views don't share gradients
            operation: Operation::Leaf,        // Views reset operation tracking
            strides: None,                     // Use default contiguous strides for simple reshapes
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                // If this is already a view, keep reference to the original base
                self.base_tensor.clone()
            } else {
                // Mark this tensor as a view of a base tensor. Note that the
                // temporary `Arc` is dropped at the end of this expression, so
                // the resulting `Weak` cannot be upgraded; it acts as a view
                // marker rather than a live back-reference.
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Create a view of a slice along a dimension (shares data, no copying)
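    ///
    /// For example (a sketch, following the `from_data` convention used in the
    /// tests below):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], DeviceType::Cpu)?;
    /// let s = t.slice_tensor(1, 1, 3)?; // columns 1..3 of each row
    /// assert_eq!(s.shape().dims(), &[2, 2]);
    /// ```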
    pub fn slice_tensor(&self, dim: usize, start: usize, end: usize) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }

        let shape = self.shape.dims();
        if start >= shape[dim] || end > shape[dim] || start >= end {
            return Err(TorshError::InvalidArgument(format!(
                "Invalid slice range [{}:{}] for dimension {} of size {}",
                start, end, dim, shape[dim]
            )));
        }

        // Calculate new shape
        let mut new_shape = shape.to_vec();
        new_shape[dim] = end - start;

        // Calculate new strides and offset
        let current_strides = self.strides();
        let offset_adjustment = start * current_strides[dim];

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(current_strides),
            storage_offset: self.storage_offset + offset_adjustment,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Create a transposed view (shares data, no copying)
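    ///
    /// Only the shape and strides are swapped, so the result is generally
    /// non-contiguous. A sketch (same `from_data` convention as the tests):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)?;
    /// let tv = t.transpose_view(0, 1)?;
    /// assert!(!tv.is_contiguous()); // strides were swapped, not the data
    /// ```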
    pub fn transpose_view(&self, dim0: usize, dim1: usize) -> Result<Self> {
        if dim0 >= self.ndim() || dim1 >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimensions {} and {} out of range for tensor with {} dimensions",
                dim0,
                dim1,
                self.ndim()
            )));
        }

        if dim0 == dim1 {
            return Ok(self.clone());
        }

        // Create new shape and strides
        let mut new_shape = self.shape.dims().to_vec();
        let mut new_strides = self.strides();

        // Swap dimensions
        new_shape.swap(dim0, dim1);
        new_strides.swap(dim0, dim1);

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(new_strides),
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Squeeze a tensor along a specific dimension (removes dimension of size 1)
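    ///
    /// A sketch (same `from_data` convention as the tests):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![1, 3], DeviceType::Cpu)?;
    /// let s = t.squeeze_tensor(0)?;
    /// assert_eq!(s.shape().dims(), &[3]);
    /// ```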
    pub fn squeeze_tensor(&self, dim: usize) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }

        let shape = self.shape.dims();
        if shape[dim] != 1 {
            return Err(TorshError::InvalidArgument(format!(
                "Cannot squeeze dimension {} of size {}",
                dim, shape[dim]
            )));
        }

        // Remove the dimension from shape and strides
        let mut new_shape = shape.to_vec();
        new_shape.remove(dim);

        let mut new_strides = self.strides();
        new_strides.remove(dim);

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(new_strides),
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Unsqueeze a tensor at a specific dimension (adds dimension of size 1)
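    ///
    /// A sketch (same `from_data` convention as the tests):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)?;
    /// let u = t.unsqueeze_tensor(1)?;
    /// assert_eq!(u.shape().dims(), &[3, 1]);
    /// ```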
    pub fn unsqueeze_tensor(&self, dim: usize) -> Result<Self> {
        if dim > self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for insertion in tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }

        // Insert new dimension into shape and strides
        let mut new_shape = self.shape.dims().to_vec();
        new_shape.insert(dim, 1);

        let mut new_strides = self.strides();
        // The stride of a size-1 dimension never affects indexing (its index is
        // always 0), so reuse the stride of the dimension being displaced, or 1
        // when appending at the end.
        let new_stride = if dim == new_shape.len() - 1 {
            1 // Last dimension always has stride 1
        } else {
            new_strides[dim] // Stride of the displaced dimension
        };
        new_strides.insert(dim, new_stride);

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(new_strides),
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Transposes two dimensions of the tensor.
    ///
    /// Swaps the specified dimensions, creating a new tensor. For 2D tensors, calling
    /// `transpose(0, 1)` produces the standard matrix transpose operation.
    ///
    /// # Arguments
    ///
    /// * `dim0` - The first dimension to swap. Negative values count from the end.
    /// * `dim1` - The second dimension to swap. Negative values count from the end.
    ///
    /// # Returns
    ///
    /// A tensor with the specified dimensions transposed.
    ///
    /// # Examples
    ///
    /// ```
    /// use torsh_tensor::creation::{zeros, arange};
    ///
    /// // Standard matrix transpose
    /// let matrix = zeros::<f32>(&[3, 4]).expect("tensor creation should succeed");
    /// let transposed = matrix.transpose(0, 1).expect("transpose should succeed");
    /// assert_eq!(transposed.shape().dims(), &[4, 3]);
    ///
    /// // Transpose in 3D tensor
    /// let cube = zeros::<f32>(&[2, 3, 4]).expect("tensor creation should succeed");
    /// let swapped = cube.transpose(0, 2).expect("transpose should succeed");
    /// assert_eq!(swapped.shape().dims(), &[4, 3, 2]);
    ///
    /// // Use negative indexing
    /// let t = zeros::<f32>(&[5, 6, 7]).expect("tensor creation should succeed");
    /// let result = t.transpose(-2, -1).expect("transpose should succeed");
    /// assert_eq!(result.shape().dims(), &[5, 7, 6]);
    ///
    /// // Practical use: convert between row-major and column-major
    /// let data = arange(0, 12, 1).expect("arange should succeed");
    /// let row_major = data.reshape(&[3, 4]).expect("reshape should succeed");
    /// let col_major = row_major.transpose(0, 1).expect("transpose should succeed");
    /// ```
    ///
    /// # See Also
    ///
    /// * [`Self::permute`] - Rearrange dimensions in arbitrary order
    /// * [`Self::view`] - Reshape to different dimensions
    pub fn transpose(&self, dim0: i32, dim1: i32) -> Result<Self> {
        let ndim = self.ndim();
        let dim0 = if dim0 < 0 {
            (ndim as i32 + dim0) as usize
        } else {
            dim0 as usize
        };
        let dim1 = if dim1 < 0 {
            (ndim as i32 + dim1) as usize
        } else {
            dim1 as usize
        };

        if dim0 >= ndim || dim1 >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimensions {} and {} out of range for tensor with {} dimensions",
                dim0, dim1, ndim
            )));
        }

        if ndim == 2 && dim0 != dim1 {
            // Fast path: materialize the 2D transpose so the result is contiguous
            self.transpose_2d()
        } else {
            // General case: return a strided view with the two dimensions swapped
            self.transpose_view(dim0, dim1)
        }
    }

    /// 2D transpose implementation
    fn transpose_2d(&self) -> Result<Self> {
        let shape = self.shape.dims();
        if shape.len() != 2 {
            return Err(TorshError::InvalidArgument(
                "transpose_2d only works with 2D tensors".to_string(),
            ));
        }

        let (rows, cols) = (shape[0], shape[1]);
        let data = self.to_vec()?;
        let mut transposed_data = Vec::with_capacity(data.len());

        for col in 0..cols {
            for row in 0..rows {
                transposed_data.push(data[row * cols + col]);
            }
        }

        Self::from_data(transposed_data, vec![cols, rows], self.device)
    }

    /// Permute dimensions according to the given order
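    ///
    /// Each dimension index must appear exactly once; negative indices count
    /// from the end. A sketch (same `from_data` convention as the tests):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)?;
    /// let p = t.permute(&[2, 0, 1])?; // [2, 3, 4] -> [4, 2, 3]
    /// assert_eq!(p.shape().dims(), &[4, 2, 3]);
    /// ```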
    pub fn permute(&self, dims: &[i32]) -> Result<Self> {
        let ndim = self.ndim();

        if dims.len() != ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Number of dimensions in permutation ({}) doesn't match tensor dimensions ({})",
                dims.len(),
                ndim
            )));
        }

        // Convert negative indices and validate
        let perm_dims: Result<Vec<usize>> = dims
            .iter()
            .map(|&d| {
                let dim = if d < 0 { ndim as i32 + d } else { d } as usize;
                if dim >= ndim {
                    Err(TorshError::InvalidArgument(format!(
                        "Dimension {} out of range for tensor with {} dimensions",
                        d, ndim
                    )))
                } else {
                    Ok(dim)
                }
            })
            .collect();

        let perm_dims = perm_dims?;

        // Check for duplicates
        let mut sorted_dims = perm_dims.clone();
        sorted_dims.sort_unstable();
        for i in 0..ndim {
            if sorted_dims[i] != i {
                return Err(TorshError::InvalidArgument(
                    "Permutation must contain each dimension exactly once".to_string(),
                ));
            }
        }

        // Create new shape and strides
        let old_shape = self.shape.dims();
        let old_strides = self.strides();

        let new_shape: Vec<usize> = perm_dims.iter().map(|&i| old_shape[i]).collect();
        let new_strides: Vec<usize> = perm_dims.iter().map(|&i| old_strides[i]).collect();

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(new_strides),
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Removes a dimension of size 1 at the specified position.
    ///
    /// This operation reduces the dimensionality of the tensor by removing dimensions
    /// that have size 1. Commonly used to remove singleton dimensions after reductions
    /// or to match tensor shapes for operations.
    ///
    /// # Arguments
    ///
    /// * `dim` - The dimension to squeeze. Negative values count from the end.
    ///
    /// # Returns
    ///
    /// A tensor with the specified dimension removed, or an error if the dimension
    /// doesn't have size 1.
    ///
    /// # Examples
    ///
    /// ```
    /// use torsh_tensor::creation::zeros;
    ///
    /// // Remove a singleton dimension
    /// let t = zeros::<f32>(&[3, 1, 4]).expect("tensor creation should succeed");
    /// let squeezed = t.squeeze(1).expect("squeeze should succeed");
    /// assert_eq!(squeezed.shape().dims(), &[3, 4]);
    ///
    /// // Use negative indexing
    /// let t2 = zeros::<f32>(&[2, 3, 1]).expect("tensor creation should succeed");
    /// let squeezed2 = t2.squeeze(-1).expect("squeeze should succeed");
    /// assert_eq!(squeezed2.shape().dims(), &[2, 3]);
    ///
    /// // After a reduction operation
    /// let matrix = zeros::<f32>(&[5, 10]).expect("tensor creation should succeed");
    /// let reduced = matrix.sum_dim(&[1], true).expect("sum_dim should succeed");  // Shape: [5, 1]
    /// let final_result = reduced.squeeze(1).expect("squeeze should succeed");  // Shape: [5]
    /// ```
    ///
    /// # See Also
    ///
    /// * [`Self::squeeze_all`] - Remove all singleton dimensions
    /// * [`Self::unsqueeze`] - Add a singleton dimension
    pub fn squeeze(&self, dim: i32) -> Result<Self> {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        self.squeeze_tensor(dim)
    }

    /// Squeeze all dimensions with size 1
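    ///
    /// A sketch (same `from_data` convention as the tests):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32, 2.0], vec![1, 2, 1], DeviceType::Cpu)?;
    /// let s = t.squeeze_all()?;
    /// assert_eq!(s.shape().dims(), &[2]);
    /// ```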
    pub fn squeeze_all(&self) -> Result<Self> {
        let shape = self.shape.dims();
        let new_shape: Vec<usize> = shape.iter().copied().filter(|&s| s != 1).collect();

        // If all dimensions were 1, `new_shape` is empty and the result is a
        // scalar (0-dimensional tensor)
        let data = self.to_vec()?;
        Self::from_data(data, new_shape, self.device)
    }

    /// Adds a dimension of size 1 at the specified position.
    ///
    /// This operation increases the dimensionality of the tensor by inserting a new
    /// dimension of size 1. Commonly used to add batch dimensions or to match tensor
    /// shapes for broadcasting operations.
    ///
    /// # Arguments
    ///
    /// * `dim` - The position to insert the new dimension. Negative values count from the end.
    ///
    /// # Returns
    ///
    /// A tensor with an additional dimension of size 1 inserted.
    ///
    /// # Examples
    ///
    /// ```
    /// use torsh_tensor::creation::zeros;
    ///
    /// // Add a batch dimension at the beginning
    /// let t = zeros::<f32>(&[3, 4]).expect("tensor creation should succeed");
    /// let batched = t.unsqueeze(0).expect("unsqueeze should succeed");
    /// assert_eq!(batched.shape().dims(), &[1, 3, 4]);
    ///
    /// // Add a dimension at the end
    /// let t2 = zeros::<f32>(&[5]).expect("tensor creation should succeed");
    /// let expanded = t2.unsqueeze(-1).expect("unsqueeze should succeed");
    /// assert_eq!(expanded.shape().dims(), &[5, 1]);
    ///
    /// // Prepare for broadcasting
    /// let weights = zeros::<f32>(&[64]).expect("tensor creation should succeed");
    /// let weights_2d = weights.unsqueeze(0).expect("unsqueeze should succeed");  // Shape: [1, 64]
    /// // Now can broadcast with shape [batch_size, 64]
    /// ```
    ///
    /// # See Also
    ///
    /// * [`Self::squeeze`] - Remove a singleton dimension
    /// * [`Self::view`] - Reshape to arbitrary shape
    pub fn unsqueeze(&self, dim: i32) -> Result<Self> {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim + 1) as usize
        } else {
            dim as usize
        };

        self.unsqueeze_tensor(dim)
    }

    /// Reshapes the tensor to a new shape.
    ///
    /// This is an alias for [`view()`](Self::view) and provides the same functionality.
    /// The total number of elements must remain the same.
    ///
    /// # Arguments
    ///
    /// * `shape` - The new shape as a slice of dimensions. Use `-1` to infer one dimension.
    ///
    /// # Returns
    ///
    /// A reshaped tensor, or an error if the reshape is invalid.
    ///
    /// # Examples
    ///
    /// ```
    /// use torsh_tensor::creation::arange;
    ///
    /// // Reshape a sequence to a matrix
    /// let t = arange(0, 12, 1).expect("arange should succeed");
    /// let matrix = t.reshape(&[3, 4]).expect("reshape should succeed");
    /// assert_eq!(matrix.shape().dims(), &[3, 4]);
    ///
    /// // Reshape with automatic dimension inference
    /// let cube = t.reshape(&[2, -1, 3]).expect("reshape should succeed");  // Infers 2 for middle dimension
    /// assert_eq!(cube.shape().dims(), &[2, 2, 3]);
    /// ```
    ///
    /// # See Also
    ///
    /// * [`Self::view`] - The underlying implementation
    pub fn reshape(&self, shape: &[i32]) -> Result<Self> {
        self.view(shape)
    }

    /// Check if tensor is contiguous in memory
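    ///
    /// A sketch (same `from_data` convention as the tests):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32; 6], vec![2, 3], DeviceType::Cpu)?;
    /// assert!(t.is_contiguous());
    /// assert!(!t.transpose_view(0, 1)?.is_contiguous());
    /// ```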
    pub fn is_contiguous(&self) -> bool {
        // A tensor is contiguous if its strides match the default strides for its shape
        let default_strides = self.compute_default_strides();
        let current_strides = self.strides();

        current_strides == default_strides
    }

    /// Make tensor contiguous if it isn't already
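    ///
    /// A sketch (same `from_data` convention as the tests):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32; 4], vec![2, 2], DeviceType::Cpu)?;
    /// let c = t.transpose_view(0, 1)?.contiguous()?; // materializes the data
    /// assert!(c.is_contiguous());
    /// ```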
    pub fn contiguous(&self) -> Result<Self> {
        if self.is_contiguous() {
            Ok(self.clone())
        } else {
            // Need to copy data to make it contiguous
            let data = self.to_vec()?;
            Self::from_data(data, self.shape.dims().to_vec(), self.device)
        }
    }

    /// Expand tensor to a larger size
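    ///
    /// Dimensions of size 1 are broadcast to the requested size, and new
    /// leading dimensions may be added. A sketch (same `from_data` convention
    /// as the tests):
    ///
    /// ```ignore
    /// let t = Tensor::from_data(vec![1.0f32, 2.0], vec![1, 2], DeviceType::Cpu)?;
    /// let e = t.expand(&[3, 2])?; // the single row is repeated three times
    /// assert_eq!(e.shape().dims(), &[3, 2]);
    /// ```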
    pub fn expand(&self, shape: &[usize]) -> Result<Self> {
        let old_shape = self.shape.dims();

        // Validate that expansion is possible
        if shape.len() < old_shape.len() {
            return Err(TorshError::InvalidShape(
                "Cannot expand to smaller number of dimensions".to_string(),
            ));
        }

        // Check dimension compatibility (broadcasting rules)
        let offset = shape.len() - old_shape.len();
        for (i, &old_dim) in old_shape.iter().enumerate() {
            let new_dim = shape[offset + i];
            if old_dim != 1 && old_dim != new_dim {
                return Err(TorshError::InvalidShape(format!(
                    "Cannot expand dimension {} from {} to {}",
                    i, old_dim, new_dim
                )));
            }
        }

        // For now, implement expansion by copying data
        // TODO: Implement efficient expansion with strided views
        let source_data = self.to_vec()?;
        let target_numel = shape.iter().product();
        let mut result_data = Vec::with_capacity(target_numel);

        self.expand_data_recursive(&source_data, &mut result_data, shape, old_shape, 0, 0)?;

        Self::from_data(result_data, shape.to_vec(), self.device)
    }

    /// Helper for recursive data expansion
    fn expand_data_recursive(
        &self,
        source: &[T],
        dest: &mut Vec<T>,
        target_shape: &[usize],
        source_shape: &[usize],
        target_dim: usize,
        source_offset: usize,
    ) -> Result<()> {
        if target_dim == target_shape.len() {
            // Base case: copy single element
            dest.push(source[source_offset]);
            return Ok(());
        }

        let target_size = target_shape[target_dim];
        // Number of leading target dimensions that have no source counterpart.
        // Computing this first avoids the usize underflow that
        // `target_dim + source_shape.len() - target_shape.len()` would hit
        // (and panic on, in debug builds) for new leading dimensions.
        let leading = target_shape.len() - source_shape.len();

        if target_dim >= leading {
            let source_dim_idx = target_dim - leading;
            let source_size = source_shape[source_dim_idx];
            let stride = if source_dim_idx + 1 < source_shape.len() {
                source_shape[source_dim_idx + 1..].iter().product()
            } else {
                1
            };

            if source_size == 1 {
                // Broadcast this dimension
                for _ in 0..target_size {
                    self.expand_data_recursive(
                        source,
                        dest,
                        target_shape,
                        source_shape,
                        target_dim + 1,
                        source_offset,
                    )?;
                }
            } else {
                // Copy along this dimension
                for i in 0..target_size {
                    self.expand_data_recursive(
                        source,
                        dest,
                        target_shape,
                        source_shape,
                        target_dim + 1,
                        source_offset + i * stride,
                    )?;
                }
            }
        } else {
            // This is a new leading dimension, repeat the entire subtensor
            for _ in 0..target_size {
                self.expand_data_recursive(
                    source,
                    dest,
                    target_shape,
                    source_shape,
                    target_dim + 1,
                    source_offset,
                )?;
            }
        }

        Ok(())
    }

    /// Move dimensions from source positions to destination positions
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.movedim(tensor, source, destination)`
    ///
    /// # Arguments
    /// * `source` - Original positions of dimensions to move
    /// * `destination` - Target positions for the dimensions
    ///
    /// # Examples
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0; 24], vec![2, 3, 4], DeviceType::Cpu)?;
    /// let y = x.movedim(&[0, 1], &[2, 0])?; // [2,3,4] -> [3,4,2]
    /// ```
    pub fn movedim(&self, source: &[isize], destination: &[isize]) -> Result<Self> {
        if source.len() != destination.len() {
            return Err(TorshError::InvalidArgument(
                "source and destination must have the same length".to_string(),
            ));
        }

        let ndim = self.ndim();

        // Normalize source and destination dimensions
        let norm_source: Result<Vec<usize>> = source
            .iter()
            .map(|&d| {
                let dim = if d < 0 {
                    (ndim as isize + d) as usize
                } else {
                    d as usize
                };
                if dim >= ndim {
                    Err(TorshError::InvalidArgument(format!(
                        "Dimension {} out of range for {}-D tensor",
                        d, ndim
                    )))
                } else {
                    Ok(dim)
                }
            })
            .collect();
        let norm_source = norm_source?;

        let norm_dest: Result<Vec<usize>> = destination
            .iter()
            .map(|&d| {
                let dim = if d < 0 {
                    (ndim as isize + d) as usize
                } else {
                    d as usize
                };
                if dim >= ndim {
                    Err(TorshError::InvalidArgument(format!(
                        "Dimension {} out of range for {}-D tensor",
                        d, ndim
                    )))
                } else {
                    Ok(dim)
                }
            })
            .collect();
        let norm_dest = norm_dest?;

        // Check for duplicates in source
        for i in 0..norm_source.len() {
            for j in i + 1..norm_source.len() {
                if norm_source[i] == norm_source[j] {
                    return Err(TorshError::InvalidArgument(
                        "repeated dim in source".to_string(),
                    ));
                }
            }
        }

        // Check for duplicates in destination
        for i in 0..norm_dest.len() {
            for j in i + 1..norm_dest.len() {
                if norm_dest[i] == norm_dest[j] {
                    return Err(TorshError::InvalidArgument(
                        "repeated dim in destination".to_string(),
                    ));
                }
            }
        }

        // Build permutation array by placing dims in final positions
        let mut result_perm = vec![0; ndim];
        let mut used = vec![false; ndim];

        // Place source dims at destination positions
        for (&src, &dst) in norm_source.iter().zip(norm_dest.iter()) {
            result_perm[dst] = src;
            used[dst] = true;
        }

        // Fill remaining positions with remaining dims in order
        let remaining_dims: Vec<usize> = (0..ndim).filter(|d| !norm_source.contains(d)).collect();

        let mut remaining_idx = 0;
        for i in 0..ndim {
            if !used[i] {
                result_perm[i] = remaining_dims[remaining_idx];
                remaining_idx += 1;
            }
        }

        // Convert usize to i32 for permute
        let perm_i32: Vec<i32> = result_perm.iter().map(|&d| d as i32).collect();
        self.permute(&perm_i32)
    }

    /// Move axis from source position to destination position (alias for movedim)
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.moveaxis(tensor, source, destination)`
    ///
    /// # Arguments
    /// * `source` - Original positions of axes to move
    /// * `destination` - Target positions for the axes
    pub fn moveaxis(&self, source: &[isize], destination: &[isize]) -> Result<Self> {
        self.movedim(source, destination)
    }

    /// Swap two dimensions
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.swapaxes(tensor, axis0, axis1)` or `torch.swapdims(tensor, dim0, dim1)`
    ///
    /// # Arguments
    /// * `axis0` - First dimension
    /// * `axis1` - Second dimension
    ///
    /// # Examples
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0; 12], vec![2, 3, 2], DeviceType::Cpu)?;
    /// let y = x.swapaxes(0, 2)?; // [2,3,2] -> [2,3,2] with dims 0 and 2 swapped
    /// ```
    pub fn swapaxes(&self, axis0: isize, axis1: isize) -> Result<Self> {
        let ndim = self.ndim();

        // Normalize dimensions
        let dim0 = if axis0 < 0 {
            (ndim as isize + axis0) as usize
        } else {
            axis0 as usize
        };
        let dim1 = if axis1 < 0 {
            (ndim as isize + axis1) as usize
        } else {
            axis1 as usize
        };

        if dim0 >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                axis0, ndim
            )));
        }
        if dim1 >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                axis1, ndim
            )));
        }

        // Build permutation: swap dim0 and dim1
        let mut perm: Vec<i32> = (0..ndim as i32).collect();
        perm.swap(dim0, dim1);

        self.permute(&perm)
    }

    /// Swap two dimensions (alias for swapaxes)
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.swapdims(tensor, dim0, dim1)`
    pub fn swapdims(&self, dim0: isize, dim1: isize) -> Result<Self> {
        self.swapaxes(dim0, dim1)
    }

    /// Broadcast tensor to a new shape
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.broadcast_to(tensor, shape)`
    ///
    /// # Arguments
    /// * `shape` - Target shape for broadcasting
    ///
    /// # Examples
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0], vec![2], DeviceType::Cpu)?;
    /// let y = x.broadcast_to(&[3, 2])?; // Broadcast [2] to [3, 2]
    /// ```
    pub fn broadcast_to(&self, shape: &[usize]) -> Result<Self> {
        // Use the existing expand method which handles broadcasting
        self.expand(shape)
    }

    /// Expand tensor to match another tensor's shape
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `Tensor.expand_as(other)` in PyTorch
    ///
    /// # Arguments
    /// * `other` - Target tensor whose shape to match
    ///
    /// # Examples
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0], vec![2], DeviceType::Cpu)?;
    /// let y = Tensor::from_data(vec![0.0; 6], vec![3, 2], DeviceType::Cpu)?;
    /// let z = x.expand_as(&y)?; // Expand x to match y's shape [3, 2]
    /// ```
    pub fn expand_as(&self, other: &Self) -> Result<Self> {
        self.broadcast_to(other.shape().dims())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_tensor_view() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let reshaped = tensor.view(&[3, 2]).expect("view should succeed");
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.numel(), 6);
    }

    #[test]
    fn test_tensor_view_with_inference() {
        let data = vec![1.0f32; 24];
        let tensor = Tensor::from_data(data, vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let reshaped = tensor.view(&[6, -1]).expect("view should succeed");
        assert_eq!(reshaped.shape().dims(), &[6, 4]);
    }

    #[test]
    fn test_tensor_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let slice = tensor.slice_tensor(1, 1, 3).expect("slice should succeed");
        assert_eq!(slice.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_tensor_transpose() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let transposed = tensor.transpose(0, 1).expect("transpose should succeed");
        assert_eq!(transposed.shape().dims(), &[2, 2]);
        assert_eq!(
            transposed.get(&[0, 1]).expect("data access should succeed"),
            3.0
        );
        assert_eq!(
            transposed.get(&[1, 0]).expect("data access should succeed"),
            2.0
        );
    }

    #[test]
    fn test_tensor_squeeze_unsqueeze() {
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![1, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let squeezed = tensor.squeeze(0).expect("squeeze should succeed");
        assert_eq!(squeezed.shape().dims(), &[3]);

        let unsqueezed = squeezed.unsqueeze(0).expect("unsqueeze should succeed");
        assert_eq!(unsqueezed.shape().dims(), &[1, 3]);
    }

    #[test]
    fn test_tensor_permute() {
        let data = vec![1.0f32; 24];
        let tensor = Tensor::from_data(data, vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let permuted = tensor.permute(&[2, 0, 1]).expect("permute should succeed");
        assert_eq!(permuted.shape().dims(), &[4, 2, 3]);
    }

    #[test]
    fn test_is_contiguous() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        assert!(tensor.is_contiguous());

        let transposed = tensor
            .transpose_view(0, 1)
            .expect("transpose view should succeed");
        assert!(!transposed.is_contiguous());

        let contiguous = transposed.contiguous().expect("contiguous should succeed");
        assert!(contiguous.is_contiguous());
    }

    #[test]
    fn test_expand() {
        let data = vec![1.0f32, 2.0];
        let tensor = Tensor::from_data(data, vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let expanded = tensor.expand(&[3, 2]).expect("expand should succeed");
        assert_eq!(expanded.shape().dims(), &[3, 2]);
        assert_eq!(expanded.numel(), 6);
    }

    #[test]
    fn test_view_error_handling() {
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Should fail - wrong total size
        assert!(tensor.view(&[2, 2]).is_err());

        // Should fail - multiple -1
        assert!(tensor.view(&[-1, -1]).is_err());
    }

    #[test]
    fn test_movedim_single() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Move dim 0 to position 2: [2,3,4] -> [3,4,2]
        let result = tensor.movedim(&[0], &[2]).expect("movedim should succeed");
        assert_eq!(result.shape().dims(), &[3, 4, 2]);
    }

    #[test]
    fn test_movedim_multiple() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Move dims [0, 1] to positions [2, 0]: [2,3,4] -> [3,4,2]
        let result = tensor
            .movedim(&[0, 1], &[2, 0])
            .expect("movedim should succeed");
        assert_eq!(result.shape().dims(), &[3, 4, 2]);
    }

    #[test]
    fn test_movedim_negative_indices() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Move last dim to first position: [2,3,4] -> [4,2,3]
        let result = tensor.movedim(&[-1], &[0]).expect("movedim should succeed");
        assert_eq!(result.shape().dims(), &[4, 2, 3]);
    }

    #[test]
    fn test_moveaxis_alias() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result1 = tensor.movedim(&[0], &[2]).expect("movedim should succeed");
        let result2 = tensor
            .moveaxis(&[0], &[2])
            .expect("moveaxis should succeed");
        assert_eq!(result1.shape().dims(), result2.shape().dims());
    }

    #[test]
    fn test_swapaxes_simple() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        // Swap dims 0 and 1: [2,3] -> [3,2]
        let result = tensor.swapaxes(0, 1).expect("swapaxes should succeed");
        assert_eq!(result.shape().dims(), &[3, 2]);
    }

    #[test]
    fn test_swapaxes_3d() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Swap dims 0 and 2: [2,3,4] -> [4,3,2]
        let result = tensor.swapaxes(0, 2).expect("swapaxes should succeed");
        assert_eq!(result.shape().dims(), &[4, 3, 2]);
    }

    #[test]
    fn test_swapaxes_negative_indices() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Swap last two dims: [2,3,4] -> [2,4,3]
        let result = tensor.swapaxes(-1, -2).expect("swapaxes should succeed");
        assert_eq!(result.shape().dims(), &[2, 4, 3]);
    }

    #[test]
    fn test_swapdims_alias() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result1 = tensor.swapaxes(0, 2).expect("swapaxes should succeed");
        let result2 = tensor.swapdims(0, 2).expect("swapdims should succeed");
        assert_eq!(result1.shape().dims(), result2.shape().dims());
    }

    #[test]
    fn test_broadcast_to_same_shape() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor
            .broadcast_to(&[2, 2])
            .expect("broadcast_to should succeed");
        assert_eq!(result.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_broadcast_to_expand_dim() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Broadcast [1, 2] to [3, 2]
        let result = tensor
            .broadcast_to(&[3, 2])
            .expect("broadcast_to should succeed");
        assert_eq!(result.shape().dims(), &[3, 2]);
    }

    #[test]
    fn test_expand_as_basic() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let target = Tensor::from_data(vec![0.0f32; 6], vec![3, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor.expand_as(&target).expect("expand_as should succeed");
        assert_eq!(result.shape().dims(), target.shape().dims());
        assert_eq!(result.shape().dims(), &[3, 2]);
    }
}
1259}