train_station/tensor/transform/
permute.rs

1//! Tensor dimension permutation operations
2//!
3//! This module provides tensor permutation functionality that rearranges the
4//! dimensions of a tensor according to a specified order. Permutation is a
5//! fundamental tensor transformation operation used in machine learning for
6//! reordering tensor axes, preparing data for specific operations, and
7//! implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `permute()` - Rearrange tensor dimensions according to specified order
12//!
13//! # Performance Characteristics
14//!
15//! * **Zero-Copy Operation**: Returns a view with reordered strides, avoiding data copying
16//! * **Memory Efficient**: Reuses existing tensor data through stride manipulation
17//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
18//! * **Shape Transformation**: Changes dimension order while preserving total elements
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Permute 2D tensor dimensions
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
27//! let permuted = tensor.permute(vec![1, 0]);
28//! assert_eq!(permuted.shape().dims(), vec![3, 2]);
29//! assert_eq!(permuted.get(&[0, 0]), 1.0);
30//! assert_eq!(permuted.get(&[1, 0]), 2.0);
31//! assert_eq!(permuted.get(&[2, 1]), 6.0);
32//! ```
33//!
34//! ```
35//! use train_station::Tensor;
36//!
37//! // Permute 3D tensor dimensions
38//! let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
39//! let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
40//! let permuted = tensor.permute(vec![2, 0, 1]);
41//! assert_eq!(permuted.shape().dims(), vec![4, 2, 3]);
42//! ```
43//!
44//! # Gradient Tracking
45//!
46//! The permute operation supports automatic gradient tracking through
47//! the GradTrack system. When `requires_grad` is enabled, the operation
48//! registers a gradient function that applies the inverse permutation
49//! during backward passes.
50
51use crate::gradtrack::{GradEngine, GradFn};
52use crate::tensor::core::Tensor;
53
54impl Tensor {
55    /// Permute tensor dimensions according to specified order
56    ///
57    /// Rearranges the dimensions of the tensor according to the provided
58    /// dimension order. This operation returns a view with reordered strides,
59    /// avoiding data copying while changing the logical arrangement of the
60    /// tensor's dimensions.
61    ///
62    /// The permutation is specified as a vector where each element represents
63    /// the new position of the corresponding dimension from the original tensor.
64    /// For example, `permute(vec![1, 0])` swaps the first two dimensions.
65    ///
66    /// # Arguments
67    ///
68    /// * `dims` - Vector specifying the new order of dimensions (must have length equal to tensor rank)
69    ///
70    /// # Returns
71    ///
72    /// A new tensor view with rearranged dimensions and correspondingly
73    /// adjusted strides. The total number of elements remains unchanged.
74    ///
75    /// # Panics
76    ///
77    /// * If `dims` length does not equal the tensor rank
78    /// * If any dimension index is out of bounds for the tensor rank
79    /// * If `dims` contains duplicate dimension indices
80    ///
81    /// # Examples
82    ///
83    /// ```
84    /// use train_station::Tensor;
85    ///
86    /// // Permute 2D tensor (swap dimensions)
87    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
88    /// let permuted = tensor.permute(vec![1, 0]);
89    /// assert_eq!(permuted.shape().dims(), vec![3, 2]);
90    /// assert_eq!(permuted.get(&[0, 0]), 1.0);
91    /// assert_eq!(permuted.get(&[1, 0]), 2.0);
92    /// assert_eq!(permuted.get(&[2, 1]), 6.0);
93    /// ```
94    ///
95    /// ```
96    /// use train_station::Tensor;
97    ///
98    /// // Permute 3D tensor (reorder dimensions)
99    /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
100    /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
101    /// let permuted = tensor.permute(vec![2, 0, 1]);
102    /// assert_eq!(permuted.shape().dims(), vec![4, 2, 3]);
103    /// assert_eq!(permuted.size(), 24); // Total elements unchanged
104    /// ```
105    ///
106    /// ```
107    /// use train_station::Tensor;
108    ///
109    /// // Permute with gradient tracking
110    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
111    /// tensor.set_requires_grad(true);
112    ///
113    /// let permuted = tensor.permute(vec![1, 0]);
114    /// assert!(permuted.requires_grad());
115    /// assert_eq!(permuted.shape().dims(), vec![2, 2]);
116    /// ```
117    ///
118    /// ```
119    /// use train_station::Tensor;
120    ///
121    /// // Identity permutation (no change)
122    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
123    /// let permuted = tensor.permute(vec![0, 1]);
124    /// assert_eq!(permuted.shape().dims(), vec![2, 2]);
125    /// assert_eq!(permuted.get(&[0, 0]), 1.0);
126    /// assert_eq!(permuted.get(&[1, 1]), 4.0);
127    /// ```
128    ///
129    /// # Performance
130    ///
131    /// - **Time Complexity**: O(1) - Returns a view with reordered strides
132    /// - **Memory Usage**: No additional memory allocation (view operation)
133    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
134    ///
135    /// # Relationship to Other Operations
136    ///
137    /// This operation is similar to `transpose()` but more general:
138    /// - `transpose(dim0, dim1)` is equivalent to `permute()` with a swap of two dimensions
139    /// - `permute()` can handle arbitrary dimension reordering for tensors of any rank
140    ///
141    /// # Memory Layout
142    ///
143    /// The permuted tensor maintains the same underlying data but with
144    /// reordered strides. This means the tensor becomes non-contiguous
145    /// unless the permutation is the identity permutation.
146    #[track_caller]
147    pub fn permute(&self, dims: Vec<usize>) -> Tensor {
148        let rank = self.shape().rank();
149        assert_eq!(
150            dims.len(),
151            rank,
152            "permute order must have length equal to rank"
153        );
154
155        // Validate dims has all unique values in range
156        {
157            let mut seen = vec![false; rank];
158            for &d in &dims {
159                assert!(
160                    d < rank,
161                    "permute index {} out of bounds for rank {}",
162                    d,
163                    rank
164                );
165                assert!(!seen[d], "duplicate dimension {} in permute", d);
166                seen[d] = true;
167            }
168        }
169
170        // Delegate to core view: generalized permutation view
171        let mut result = match crate::tensor::core::view::transpose_view(self, &dims) {
172            Ok(v) => v,
173            Err(e) => panic!("permute view error: {:?}", e),
174        };
175
176        // GradTrack: register permute for backward (inverse permutation)
177        if self.requires_grad() {
178            result.set_requires_grad(true);
179            let grad_fn = GradFn::Permute {
180                dims: dims.clone(),
181                input_shape: self.shape().dims().to_vec(),
182            };
183            result.set_grad_fn(grad_fn.clone());
184            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
185        }
186
187        result
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Swapping the axes of a 2x3 tensor must give a 3x2 tensor, and the
    /// sampled elements must match the values asserted below.
    #[test]
    fn test_permute_basic_2d() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let swapped = source.permute(vec![1, 0]);

        assert_eq!(swapped.shape().dims(), vec![3, 2]);

        // (index into the permuted tensor, value expected at that index)
        for (index, expected) in [([0, 0], 1.0), ([1, 0], 2.0), ([2, 1], 6.0)] {
            assert_eq!(swapped.get(&index), expected);
        }
    }
}