train_station/tensor/transform/permute.rs
1//! Tensor dimension permutation operations
2//!
3//! This module provides tensor permutation functionality that rearranges the
4//! dimensions of a tensor according to a specified order. Permutation is a
5//! fundamental tensor transformation operation used in machine learning for
6//! reordering tensor axes, preparing data for specific operations, and
7//! implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `permute()` - Rearrange tensor dimensions according to specified order
12//!
13//! # Performance Characteristics
14//!
15//! * **Zero-Copy Operation**: Returns a view with reordered strides, avoiding data copying
16//! * **Memory Efficient**: Reuses existing tensor data through stride manipulation
17//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
18//! * **Shape Transformation**: Changes dimension order while preserving total elements
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Permute 2D tensor dimensions
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
27//! let permuted = tensor.permute(vec![1, 0]);
28//! assert_eq!(permuted.shape().dims(), vec![3, 2]);
29//! assert_eq!(permuted.get(&[0, 0]), 1.0);
30//! assert_eq!(permuted.get(&[1, 0]), 2.0);
31//! assert_eq!(permuted.get(&[2, 1]), 6.0);
32//! ```
33//!
34//! ```
35//! use train_station::Tensor;
36//!
37//! // Permute 3D tensor dimensions
38//! let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
39//! let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
40//! let permuted = tensor.permute(vec![2, 0, 1]);
41//! assert_eq!(permuted.shape().dims(), vec![4, 2, 3]);
42//! ```
43//!
44//! # Gradient Tracking
45//!
46//! The permute operation supports automatic gradient tracking through
47//! the GradTrack system. When `requires_grad` is enabled, the operation
48//! registers a gradient function that applies the inverse permutation
49//! during backward passes.
50
51use crate::gradtrack::{GradEngine, GradFn};
52use crate::tensor::core::Tensor;
53
54impl Tensor {
55 /// Permute tensor dimensions according to specified order
56 ///
57 /// Rearranges the dimensions of the tensor according to the provided
58 /// dimension order. This operation returns a view with reordered strides,
59 /// avoiding data copying while changing the logical arrangement of the
60 /// tensor's dimensions.
61 ///
62 /// The permutation is specified as a vector where each element represents
63 /// the new position of the corresponding dimension from the original tensor.
64 /// For example, `permute(vec![1, 0])` swaps the first two dimensions.
65 ///
66 /// # Arguments
67 ///
68 /// * `dims` - Vector specifying the new order of dimensions (must have length equal to tensor rank)
69 ///
70 /// # Returns
71 ///
72 /// A new tensor view with rearranged dimensions and correspondingly
73 /// adjusted strides. The total number of elements remains unchanged.
74 ///
75 /// # Panics
76 ///
77 /// * If `dims` length does not equal the tensor rank
78 /// * If any dimension index is out of bounds for the tensor rank
79 /// * If `dims` contains duplicate dimension indices
80 ///
81 /// # Examples
82 ///
83 /// ```
84 /// use train_station::Tensor;
85 ///
86 /// // Permute 2D tensor (swap dimensions)
87 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
88 /// let permuted = tensor.permute(vec![1, 0]);
89 /// assert_eq!(permuted.shape().dims(), vec![3, 2]);
90 /// assert_eq!(permuted.get(&[0, 0]), 1.0);
91 /// assert_eq!(permuted.get(&[1, 0]), 2.0);
92 /// assert_eq!(permuted.get(&[2, 1]), 6.0);
93 /// ```
94 ///
95 /// ```
96 /// use train_station::Tensor;
97 ///
98 /// // Permute 3D tensor (reorder dimensions)
99 /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
100 /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
101 /// let permuted = tensor.permute(vec![2, 0, 1]);
102 /// assert_eq!(permuted.shape().dims(), vec![4, 2, 3]);
103 /// assert_eq!(permuted.size(), 24); // Total elements unchanged
104 /// ```
105 ///
106 /// ```
107 /// use train_station::Tensor;
108 ///
109 /// // Permute with gradient tracking
110 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
111 /// tensor.set_requires_grad(true);
112 ///
113 /// let permuted = tensor.permute(vec![1, 0]);
114 /// assert!(permuted.requires_grad());
115 /// assert_eq!(permuted.shape().dims(), vec![2, 2]);
116 /// ```
117 ///
118 /// ```
119 /// use train_station::Tensor;
120 ///
121 /// // Identity permutation (no change)
122 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
123 /// let permuted = tensor.permute(vec![0, 1]);
124 /// assert_eq!(permuted.shape().dims(), vec![2, 2]);
125 /// assert_eq!(permuted.get(&[0, 0]), 1.0);
126 /// assert_eq!(permuted.get(&[1, 1]), 4.0);
127 /// ```
128 ///
129 /// # Performance
130 ///
131 /// - **Time Complexity**: O(1) - Returns a view with reordered strides
132 /// - **Memory Usage**: No additional memory allocation (view operation)
133 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
134 ///
135 /// # Relationship to Other Operations
136 ///
137 /// This operation is similar to `transpose()` but more general:
138 /// - `transpose(dim0, dim1)` is equivalent to `permute()` with a swap of two dimensions
139 /// - `permute()` can handle arbitrary dimension reordering for tensors of any rank
140 ///
141 /// # Memory Layout
142 ///
143 /// The permuted tensor maintains the same underlying data but with
144 /// reordered strides. This means the tensor becomes non-contiguous
145 /// unless the permutation is the identity permutation.
146 #[track_caller]
147 pub fn permute(&self, dims: Vec<usize>) -> Tensor {
148 let rank = self.shape().rank();
149 assert_eq!(
150 dims.len(),
151 rank,
152 "permute order must have length equal to rank"
153 );
154
155 // Validate dims has all unique values in range
156 {
157 let mut seen = vec![false; rank];
158 for &d in &dims {
159 assert!(
160 d < rank,
161 "permute index {} out of bounds for rank {}",
162 d,
163 rank
164 );
165 assert!(!seen[d], "duplicate dimension {} in permute", d);
166 seen[d] = true;
167 }
168 }
169
170 // Delegate to core view: generalized permutation view
171 let mut result = match crate::tensor::core::view::transpose_view(self, &dims) {
172 Ok(v) => v,
173 Err(e) => panic!("permute view error: {:?}", e),
174 };
175
176 // GradTrack: register permute for backward (inverse permutation)
177 if self.requires_grad() {
178 result.set_requires_grad(true);
179 let grad_fn = GradFn::Permute {
180 dims: dims.clone(),
181 input_shape: self.shape().dims().to_vec(),
182 };
183 result.set_grad_fn(grad_fn.clone());
184 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
185 }
186
187 result
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Swapping the axes of a 2x3 tensor must produce a 3x2 tensor whose
    /// elements are read back at the swapped coordinates.
    #[test]
    fn test_permute_basic_2d() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let swapped = source.permute(vec![1, 0]);

        assert_eq!(swapped.shape().dims(), vec![3, 2]);

        // (index into the permuted tensor, expected element value)
        let expected = [([0, 0], 1.0), ([1, 0], 2.0), ([2, 1], 6.0)];
        for (index, value) in expected {
            assert_eq!(swapped.get(&index), value);
        }
    }
}