train_station/tensor/transform/
permute.rs

1//! Tensor dimension permutation operations
2//!
3//! This module provides tensor permutation functionality that rearranges the
4//! dimensions of a tensor according to a specified order. Permutation is a
5//! fundamental tensor transformation operation used in machine learning for
6//! reordering tensor axes, preparing data for specific operations, and
7//! implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `permute()` - Rearrange tensor dimensions according to specified order
12//!
13//! # Performance Characteristics
14//!
15//! * **Zero-Copy Operation**: Returns a view with reordered strides, avoiding data copying
16//! * **Memory Efficient**: Reuses existing tensor data through stride manipulation
17//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
18//! * **Shape Transformation**: Changes dimension order while preserving total elements
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Permute 2D tensor dimensions
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
27//! let permuted = tensor.permute(vec![1, 0]);
28//! assert_eq!(permuted.shape().dims, vec![3, 2]);
29//! assert_eq!(permuted.get(&[0, 0]), 1.0);
30//! assert_eq!(permuted.get(&[1, 0]), 2.0);
31//! assert_eq!(permuted.get(&[2, 1]), 6.0);
32//! ```
33//!
34//! ```
35//! use train_station::Tensor;
36//!
37//! // Permute 3D tensor dimensions
38//! let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
39//! let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
40//! let permuted = tensor.permute(vec![2, 0, 1]);
41//! assert_eq!(permuted.shape().dims, vec![4, 2, 3]);
42//! ```
43//!
44//! # Gradient Tracking
45//!
46//! The permute operation supports automatic gradient tracking through
47//! the GradTrack system. When `requires_grad` is enabled, the operation
48//! registers a gradient function that applies the inverse permutation
49//! during backward passes.
50
51use crate::gradtrack::{GradEngine, GradFn};
52use crate::tensor::core::Tensor;
53use crate::tensor::Shape;
54
55impl Tensor {
56    /// Permute tensor dimensions according to specified order
57    ///
58    /// Rearranges the dimensions of the tensor according to the provided
59    /// dimension order. This operation returns a view with reordered strides,
60    /// avoiding data copying while changing the logical arrangement of the
61    /// tensor's dimensions.
62    ///
63    /// The permutation is specified as a vector where each element represents
64    /// the new position of the corresponding dimension from the original tensor.
65    /// For example, `permute(vec![1, 0])` swaps the first two dimensions.
66    ///
67    /// # Arguments
68    ///
69    /// * `dims` - Vector specifying the new order of dimensions (must have length equal to tensor rank)
70    ///
71    /// # Returns
72    ///
73    /// A new tensor view with rearranged dimensions and correspondingly
74    /// adjusted strides. The total number of elements remains unchanged.
75    ///
76    /// # Panics
77    ///
78    /// * If `dims` length does not equal the tensor rank
79    /// * If any dimension index is out of bounds for the tensor rank
80    /// * If `dims` contains duplicate dimension indices
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// use train_station::Tensor;
86    ///
87    /// // Permute 2D tensor (swap dimensions)
88    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
89    /// let permuted = tensor.permute(vec![1, 0]);
90    /// assert_eq!(permuted.shape().dims, vec![3, 2]);
91    /// assert_eq!(permuted.get(&[0, 0]), 1.0);
92    /// assert_eq!(permuted.get(&[1, 0]), 2.0);
93    /// assert_eq!(permuted.get(&[2, 1]), 6.0);
94    /// ```
95    ///
96    /// ```
97    /// use train_station::Tensor;
98    ///
99    /// // Permute 3D tensor (reorder dimensions)
100    /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
101    /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
102    /// let permuted = tensor.permute(vec![2, 0, 1]);
103    /// assert_eq!(permuted.shape().dims, vec![4, 2, 3]);
104    /// assert_eq!(permuted.size(), 24); // Total elements unchanged
105    /// ```
106    ///
107    /// ```
108    /// use train_station::Tensor;
109    ///
110    /// // Permute with gradient tracking
111    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
112    /// tensor.set_requires_grad(true);
113    ///
114    /// let permuted = tensor.permute(vec![1, 0]);
115    /// assert!(permuted.requires_grad());
116    /// assert_eq!(permuted.shape().dims, vec![2, 2]);
117    /// ```
118    ///
119    /// ```
120    /// use train_station::Tensor;
121    ///
122    /// // Identity permutation (no change)
123    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
124    /// let permuted = tensor.permute(vec![0, 1]);
125    /// assert_eq!(permuted.shape().dims, vec![2, 2]);
126    /// assert_eq!(permuted.get(&[0, 0]), 1.0);
127    /// assert_eq!(permuted.get(&[1, 1]), 4.0);
128    /// ```
129    ///
130    /// # Performance
131    ///
132    /// - **Time Complexity**: O(1) - Returns a view with reordered strides
133    /// - **Memory Usage**: No additional memory allocation (view operation)
134    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
135    ///
136    /// # Relationship to Other Operations
137    ///
138    /// This operation is similar to `transpose()` but more general:
139    /// - `transpose(dim0, dim1)` is equivalent to `permute()` with a swap of two dimensions
140    /// - `permute()` can handle arbitrary dimension reordering for tensors of any rank
141    ///
142    /// # Memory Layout
143    ///
144    /// The permuted tensor maintains the same underlying data but with
145    /// reordered strides. This means the tensor becomes non-contiguous
146    /// unless the permutation is the identity permutation.
147    pub fn permute(&self, dims: Vec<usize>) -> Tensor {
148        let rank = self.shape().rank();
149        assert_eq!(
150            dims.len(),
151            rank,
152            "permute order must have length equal to rank"
153        );
154
155        // Validate dims has all unique values in range
156        {
157            let mut seen = vec![false; rank];
158            for &d in &dims {
159                assert!(
160                    d < rank,
161                    "permute index {} out of bounds for rank {}",
162                    d,
163                    rank
164                );
165                assert!(!seen[d], "duplicate dimension {} in permute", d);
166                seen[d] = true;
167            }
168        }
169
170        // Compute new dims and strides for view
171        let mut new_dims = Vec::with_capacity(rank);
172        for &d in &dims {
173            new_dims.push(self.shape().dims[d]);
174        }
175        // Reorder strides accordingly
176        let mut new_strides = Vec::with_capacity(rank);
177        for &d in &dims {
178            new_strides.push(self.stride(d));
179        }
180
181        // Create a non-copy view with strided layout
182        let view_shape = Shape::as_view(new_dims, new_strides);
183        let mut result = self.create_view_with_shape(view_shape);
184
185        // GradTrack: register permute for backward (inverse permutation)
186        if self.requires_grad() {
187            result.set_requires_grad(true);
188            let grad_fn = GradFn::Permute {
189                dims: dims.clone(),
190                input_shape: self.shape().dims.clone(),
191            };
192            result.set_grad_fn(grad_fn.clone());
193            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
194        }
195
196        result
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Swapping the two axes of a 2x3 tensor must yield a 3x2 result whose
    /// element at `[j, i]` equals the source element at `[i, j]`.
    #[test]
    fn test_permute_basic_2d() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let swapped = source.permute(vec![1, 0]);

        assert_eq!(swapped.shape().dims, vec![3, 2]);

        // (index into the permuted tensor, value expected at that index)
        let probes = [([0usize, 0], 1.0), ([1, 0], 2.0), ([2, 1], 6.0)];
        for (index, expected) in probes.iter() {
            assert_eq!(swapped.get(index), *expected);
        }
    }
}