train_station/tensor/transform/permute.rs
1//! Tensor dimension permutation operations
2//!
3//! This module provides tensor permutation functionality that rearranges the
4//! dimensions of a tensor according to a specified order. Permutation is a
5//! fundamental tensor transformation operation used in machine learning for
6//! reordering tensor axes, preparing data for specific operations, and
7//! implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `permute()` - Rearrange tensor dimensions according to specified order
12//!
13//! # Performance Characteristics
14//!
15//! * **Zero-Copy Operation**: Returns a view with reordered strides, avoiding data copying
16//! * **Memory Efficient**: Reuses existing tensor data through stride manipulation
17//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
18//! * **Shape Transformation**: Changes dimension order while preserving total elements
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Permute 2D tensor dimensions
26//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
27//! let permuted = tensor.permute(vec![1, 0]);
28//! assert_eq!(permuted.shape().dims, vec![3, 2]);
29//! assert_eq!(permuted.get(&[0, 0]), 1.0);
30//! assert_eq!(permuted.get(&[1, 0]), 2.0);
31//! assert_eq!(permuted.get(&[2, 1]), 6.0);
32//! ```
33//!
34//! ```
35//! use train_station::Tensor;
36//!
37//! // Permute 3D tensor dimensions
38//! let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
39//! let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
40//! let permuted = tensor.permute(vec![2, 0, 1]);
41//! assert_eq!(permuted.shape().dims, vec![4, 2, 3]);
42//! ```
43//!
44//! # Gradient Tracking
45//!
46//! The permute operation supports automatic gradient tracking through
47//! the GradTrack system. When `requires_grad` is enabled, the operation
48//! registers a gradient function that applies the inverse permutation
49//! during backward passes.
50
51use crate::gradtrack::{GradEngine, GradFn};
52use crate::tensor::core::Tensor;
53use crate::tensor::Shape;
54
55impl Tensor {
56 /// Permute tensor dimensions according to specified order
57 ///
58 /// Rearranges the dimensions of the tensor according to the provided
59 /// dimension order. This operation returns a view with reordered strides,
60 /// avoiding data copying while changing the logical arrangement of the
61 /// tensor's dimensions.
62 ///
63 /// The permutation is specified as a vector where each element represents
64 /// the new position of the corresponding dimension from the original tensor.
65 /// For example, `permute(vec![1, 0])` swaps the first two dimensions.
66 ///
67 /// # Arguments
68 ///
69 /// * `dims` - Vector specifying the new order of dimensions (must have length equal to tensor rank)
70 ///
71 /// # Returns
72 ///
73 /// A new tensor view with rearranged dimensions and correspondingly
74 /// adjusted strides. The total number of elements remains unchanged.
75 ///
76 /// # Panics
77 ///
78 /// * If `dims` length does not equal the tensor rank
79 /// * If any dimension index is out of bounds for the tensor rank
80 /// * If `dims` contains duplicate dimension indices
81 ///
82 /// # Examples
83 ///
84 /// ```
85 /// use train_station::Tensor;
86 ///
87 /// // Permute 2D tensor (swap dimensions)
88 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
89 /// let permuted = tensor.permute(vec![1, 0]);
90 /// assert_eq!(permuted.shape().dims, vec![3, 2]);
91 /// assert_eq!(permuted.get(&[0, 0]), 1.0);
92 /// assert_eq!(permuted.get(&[1, 0]), 2.0);
93 /// assert_eq!(permuted.get(&[2, 1]), 6.0);
94 /// ```
95 ///
96 /// ```
97 /// use train_station::Tensor;
98 ///
99 /// // Permute 3D tensor (reorder dimensions)
100 /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
101 /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
102 /// let permuted = tensor.permute(vec![2, 0, 1]);
103 /// assert_eq!(permuted.shape().dims, vec![4, 2, 3]);
104 /// assert_eq!(permuted.size(), 24); // Total elements unchanged
105 /// ```
106 ///
107 /// ```
108 /// use train_station::Tensor;
109 ///
110 /// // Permute with gradient tracking
111 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
112 /// tensor.set_requires_grad(true);
113 ///
114 /// let permuted = tensor.permute(vec![1, 0]);
115 /// assert!(permuted.requires_grad());
116 /// assert_eq!(permuted.shape().dims, vec![2, 2]);
117 /// ```
118 ///
119 /// ```
120 /// use train_station::Tensor;
121 ///
122 /// // Identity permutation (no change)
123 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
124 /// let permuted = tensor.permute(vec![0, 1]);
125 /// assert_eq!(permuted.shape().dims, vec![2, 2]);
126 /// assert_eq!(permuted.get(&[0, 0]), 1.0);
127 /// assert_eq!(permuted.get(&[1, 1]), 4.0);
128 /// ```
129 ///
130 /// # Performance
131 ///
132 /// - **Time Complexity**: O(1) - Returns a view with reordered strides
133 /// - **Memory Usage**: No additional memory allocation (view operation)
134 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
135 ///
136 /// # Relationship to Other Operations
137 ///
138 /// This operation is similar to `transpose()` but more general:
139 /// - `transpose(dim0, dim1)` is equivalent to `permute()` with a swap of two dimensions
140 /// - `permute()` can handle arbitrary dimension reordering for tensors of any rank
141 ///
142 /// # Memory Layout
143 ///
144 /// The permuted tensor maintains the same underlying data but with
145 /// reordered strides. This means the tensor becomes non-contiguous
146 /// unless the permutation is the identity permutation.
147 #[track_caller]
148 pub fn permute(&self, dims: Vec<usize>) -> Tensor {
149 let rank = self.shape().rank();
150 assert_eq!(
151 dims.len(),
152 rank,
153 "permute order must have length equal to rank"
154 );
155
156 // Validate dims has all unique values in range
157 {
158 let mut seen = vec![false; rank];
159 for &d in &dims {
160 assert!(
161 d < rank,
162 "permute index {} out of bounds for rank {}",
163 d,
164 rank
165 );
166 assert!(!seen[d], "duplicate dimension {} in permute", d);
167 seen[d] = true;
168 }
169 }
170
171 // Compute new dims and strides for view
172 let mut new_dims = Vec::with_capacity(rank);
173 for &d in &dims {
174 new_dims.push(self.shape().dims[d]);
175 }
176 // Reorder strides accordingly
177 let mut new_strides = Vec::with_capacity(rank);
178 for &d in &dims {
179 new_strides.push(self.stride(d));
180 }
181
182 // Create a non-copy view with strided layout
183 let view_shape = Shape::as_view(new_dims, new_strides);
184 let mut result = self.create_view_with_shape(view_shape);
185
186 // GradTrack: register permute for backward (inverse permutation)
187 if self.requires_grad() {
188 result.set_requires_grad(true);
189 let grad_fn = GradFn::Permute {
190 dims: dims.clone(),
191 input_shape: self.shape().dims.clone(),
192 };
193 result.set_grad_fn(grad_fn.clone());
194 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
195 }
196
197 result
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Swapping the two axes of a `[2, 3]` tensor transposes it.
    #[test]
    fn test_permute_basic_2d() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let y = x.permute(vec![1, 0]);
        assert_eq!(y.shape().dims, vec![3, 2]);
        assert_eq!(y.get(&[0, 0]), 1.0);
        assert_eq!(y.get(&[1, 0]), 2.0);
        assert_eq!(y.get(&[2, 1]), 6.0);
    }

    /// Output dimension `i` is input dimension `dims[i]`, for both shape and values.
    #[test]
    fn test_permute_3d_shape_and_values() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        let y = x.permute(vec![2, 0, 1]);
        assert_eq!(y.shape().dims, vec![4, 2, 3]);
        assert_eq!(y.size(), 24);
        // y[k, i, j] == x[i, j, k] == i * 12 + j * 4 + k
        assert_eq!(y.get(&[0, 1, 0]), 12.0);
        assert_eq!(y.get(&[1, 0, 2]), 9.0);
        assert_eq!(y.get(&[3, 1, 2]), 23.0);
    }

    /// The identity permutation leaves shape and values unchanged.
    #[test]
    fn test_permute_identity() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let y = x.permute(vec![0, 1]);
        assert_eq!(y.shape().dims, vec![2, 2]);
        assert_eq!(y.get(&[0, 0]), 1.0);
        assert_eq!(y.get(&[0, 1]), 2.0);
        assert_eq!(y.get(&[1, 0]), 3.0);
        assert_eq!(y.get(&[1, 1]), 4.0);
    }

    /// Applying a permutation and then its inverse restores the original indexing.
    #[test]
    fn test_permute_round_trip() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        // Inverse of [2, 0, 1] is [1, 2, 0].
        let z = x.permute(vec![2, 0, 1]).permute(vec![1, 2, 0]);
        assert_eq!(z.shape().dims, vec![2, 3, 4]);
        assert_eq!(z.get(&[0, 0, 0]), 0.0);
        assert_eq!(z.get(&[1, 2, 3]), 23.0);
        assert_eq!(z.get(&[0, 1, 2]), 6.0);
    }

    /// `requires_grad` propagates from input to the permuted view.
    #[test]
    fn test_permute_propagates_requires_grad() {
        let mut x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        x.set_requires_grad(true);
        let y = x.permute(vec![1, 0]);
        assert!(y.requires_grad());
        assert_eq!(y.shape().dims, vec![2, 2]);
    }

    #[test]
    #[should_panic(expected = "permute order must have length equal to rank")]
    fn test_permute_wrong_length_panics() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let _ = x.permute(vec![0]);
    }

    #[test]
    #[should_panic(expected = "permute index 2 out of bounds for rank 2")]
    fn test_permute_out_of_bounds_panics() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let _ = x.permute(vec![0, 2]);
    }

    #[test]
    #[should_panic(expected = "duplicate dimension 1 in permute")]
    fn test_permute_duplicate_panics() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let _ = x.permute(vec![1, 1]);
    }
}