train_station/tensor/transform/transpose.rs
//! Tensor transpose operations
//!
//! This module provides tensor transpose functionality that swaps dimensions
//! of tensors, effectively changing the memory access pattern and logical
//! arrangement of data. Transposition is a fundamental tensor transformation
//! operation used in machine learning for matrix operations, preparing data
//! for specific layer types, and implementing complex tensor manipulations.
//!
//! # Operations
//!
//! * `transpose()` - Swap two specified dimensions of a tensor
//! * `t()` - Matrix transpose (swap last two dimensions)
//!
//! # Performance Characteristics
//!
//! * **Zero-Copy Operations**: Returns a view when possible using stride manipulation
//! * **Memory Efficient**: Reuses existing tensor data through view operations
//! * **Cache Optimized**: Uses optimized copying when view operations are not possible
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//! * **Shape Transformation**: Changes dimension order while preserving total elements
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Basic 2D transpose
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
//! let transposed = tensor.transpose(0, 1);
//! assert_eq!(transposed.shape().dims(), vec![3, 2]);
//! ```
//!
//! ```
//! use train_station::Tensor;
//!
//! // Matrix transpose convenience method
//! let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let transposed = matrix.t();
//! assert_eq!(transposed.shape().dims(), vec![2, 2]);
//! ```
//!
//! # Gradient Tracking
//!
//! The transpose operations support automatic gradient tracking through
//! the GradTrack system. When `requires_grad` is enabled, the operations
//! register gradient functions that apply the inverse transpose during
//! backward passes.
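//!
//! Because a transpose is its own inverse, the gradient rule is simply the same
//! dimension swap applied again. The round-trip below (a small sketch using only
//! the forward API) illustrates the property the backward pass relies on:
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
//! let round_trip = tensor.transpose(0, 1).transpose(0, 1);
//! assert_eq!(round_trip.shape().dims(), vec![2, 3]);
//! assert_eq!(round_trip.get(&[1, 2]), 6.0);
//! ```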

use crate::tensor::core::Tensor;

impl Tensor {
    /// Transpose two dimensions of the tensor
    ///
    /// Swaps two specified dimensions of the tensor, modifying the shape
    /// and memory access pattern. When possible, this operation returns
    /// a zero-copy view using stride manipulation. For complex cases or
    /// non-contiguous tensors, data is copied to ensure correct transposition.
    ///
    /// The transpose operation is its own inverse - applying transpose
    /// twice with the same dimensions returns the original tensor.
    ///
    /// # Arguments
    ///
    /// * `dim0` - First dimension to swap (must be < tensor rank)
    /// * `dim1` - Second dimension to swap (must be < tensor rank)
    ///
    /// # Returns
    ///
    /// A new tensor with the specified dimensions transposed. The total
    /// number of elements remains unchanged.
    ///
    /// # Panics
    ///
    /// * If `dim0` is out of bounds for the tensor rank
    /// * If `dim1` is out of bounds for the tensor rank
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Basic 2D transpose
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    /// let transposed = tensor.transpose(0, 1);
    /// assert_eq!(transposed.shape().dims(), vec![3, 2]);
    /// assert_eq!(transposed.get(&[0, 0]), 1.0);
    /// assert_eq!(transposed.get(&[0, 1]), 4.0);
    /// assert_eq!(transposed.get(&[1, 0]), 2.0);
    /// assert_eq!(transposed.get(&[1, 1]), 5.0);
    /// assert_eq!(transposed.get(&[2, 0]), 3.0);
    /// assert_eq!(transposed.get(&[2, 1]), 6.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // 3D tensor transpose
    /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
    /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
    /// let transposed = tensor.transpose(0, 1);
    /// assert_eq!(transposed.shape().dims(), vec![3, 2, 4]);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Transpose with gradient tracking
    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// tensor.set_requires_grad(true);
    ///
    /// let transposed = tensor.transpose(0, 1);
    /// assert!(transposed.requires_grad());
    /// assert_eq!(transposed.shape().dims(), vec![2, 2]);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Transpose same dimension (no change)
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let result = tensor.transpose(1, 1);
    /// assert_eq!(result.shape().dims(), tensor.shape().dims());
    /// assert_eq!(result.get(&[0, 0]), tensor.get(&[0, 0]));
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Transpose is its own inverse
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let transposed = tensor.transpose(0, 1);
    /// let double_transposed = transposed.transpose(0, 1);
    /// assert_eq!(double_transposed.shape().dims(), tensor.shape().dims());
    /// assert_eq!(double_transposed.get(&[0, 0]), tensor.get(&[0, 0]));
    /// ```
    ///
    /// # Performance
    ///
    /// - **Contiguous tensors**: O(1) time complexity, returns a view
    /// - **Non-contiguous tensors**: O(n) time complexity with data copying
    /// - **Memory usage**: No additional allocation for view operations
    /// - **Gradient tracking**: Preserves gradient requirements and tracking
    ///
    /// # Relationship to Other Operations
    ///
    /// This operation is related to other tensor transformations:
    /// - `t()` - Convenience method for matrix transpose (last two dimensions)
    /// - `permute()` - More general dimension reordering operation
    /// - `reshape()` - Changes shape without changing dimension order
    ///
    /// # Memory Layout
    ///
    /// For contiguous tensors, transpose returns a view with modified strides,
    /// making the tensor non-contiguous. For non-contiguous tensors or complex
    /// cases, data is copied to ensure correct transposition.
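    ///
    /// A minimal sketch (mirroring the unit tests) that checks contiguity before the
    /// transpose and only the logical shape afterwards, since whether the result is a
    /// non-contiguous view or a fresh copy depends on which path is taken:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![3, 4]);
    /// assert!(tensor.is_contiguous());
    ///
    /// let transposed = tensor.transpose(0, 1);
    /// assert_eq!(transposed.shape().dims(), vec![4, 3]);
    /// ```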
    #[track_caller]
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor {
        assert!(
            dim0 < self.shape().rank(),
            "dim0 {} out of bounds for tensor with rank {}",
            dim0,
            self.shape().rank()
        );
        assert!(
            dim1 < self.shape().rank(),
            "dim1 {} out of bounds for tensor with rank {}",
            dim1,
            self.shape().rank()
        );

        // If same dimension, return a clone
        if dim0 == dim1 {
            return self.clone();
        }

        // Build permutation and delegate to core view
        let rank = self.shape().rank();
        let mut perm: Vec<usize> = (0..rank).collect();
        perm.swap(dim0, dim1);
        let mut result = match crate::tensor::core::view::transpose_view(self, &perm) {
            Ok(v) => v,
            Err(e) => panic!("transpose view error: {:?}", e),
        };

        // GradTrack: register transpose for backward (transpose is its own inverse)
        if self.requires_grad() {
            result.set_requires_grad(true);
            let grad_fn = crate::gradtrack::grad_fn::GradFn::Transpose {
                dim0,
                dim1,
                input_shape: self.shape().dims().to_vec(),
            };
            result.set_grad_fn(grad_fn.clone());
            crate::gradtrack::engine::GradEngine::register_operation(
                result.id(),
                vec![self.id()],
                grad_fn,
            );
        }

        result
    }

    /// Matrix transpose (transpose last two dimensions)
    ///
    /// Convenience method for the common case of matrix transposition.
    /// For 2D tensors, this performs a standard matrix transpose.
    /// For higher-dimensional tensors, this transposes the last two
    /// dimensions, treating the tensor as a batch of matrices.
    ///
    /// This method is equivalent to `transpose(rank-2, rank-1)` where
    /// `rank` is the number of dimensions in the tensor.
    ///
    /// # Returns
    ///
    /// A new tensor with the last two dimensions transposed
    ///
    /// # Panics
    ///
    /// * If the tensor has fewer than 2 dimensions
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // 2D matrix transpose
    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let transposed = matrix.t();
    /// assert_eq!(transposed.shape().dims(), vec![2, 2]);
    /// assert_eq!(transposed.get(&[0, 0]), 1.0);
    /// assert_eq!(transposed.get(&[0, 1]), 3.0);
    /// assert_eq!(transposed.get(&[1, 0]), 2.0);
    /// assert_eq!(transposed.get(&[1, 1]), 4.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // 3D tensor (batch of matrices)
    /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
    /// let transposed = tensor.t();
    /// assert_eq!(transposed.shape().dims(), vec![2, 3, 2]);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Matrix transpose with gradient tracking
    /// let mut matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// matrix.set_requires_grad(true);
    ///
    /// let transposed = matrix.t();
    /// assert!(transposed.requires_grad());
    /// assert_eq!(transposed.shape().dims(), vec![2, 2]);
    /// ```
    ///
    /// # Performance
    ///
    /// - **Time Complexity**: Same as `transpose()` - O(1) for views, O(n) for copies
    /// - **Memory Usage**: Same as `transpose()` - no allocation for views
    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
    ///
    /// # Relationship to Other Operations
    ///
    /// This operation is equivalent to:
    /// ```rust
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// let rank = tensor.shape().rank();
    /// let transposed1 = tensor.t();
    /// let transposed2 = tensor.transpose(rank - 2, rank - 1);
    /// // transposed1 and transposed2 are identical
    /// ```
    #[track_caller]
    pub fn t(&self) -> Tensor {
        assert!(
            self.shape().rank() >= 2,
            "Matrix transpose requires at least 2 dimensions, got {}",
            self.shape().rank()
        );
        let rank = self.shape().rank();
        self.transpose(rank - 2, rank - 1)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transpose_2d_basic() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .expect("Failed to create tensor");
        let transposed = tensor.transpose(0, 1);

        assert_eq!(transposed.shape().dims(), vec![3, 2]);

        // Verify data layout: original [2,3] -> transposed [3,2]
        // Original: [[1,2,3], [4,5,6]]
        // Transposed: [[1,4], [2,5], [3,6]]
        assert_eq!(transposed.get(&[0, 0]), 1.0);
        assert_eq!(transposed.get(&[0, 1]), 4.0);
        assert_eq!(transposed.get(&[1, 0]), 2.0);
        assert_eq!(transposed.get(&[1, 1]), 5.0);
        assert_eq!(transposed.get(&[2, 0]), 3.0);
        assert_eq!(transposed.get(&[2, 1]), 6.0);
    }

    #[test]
    fn test_matrix_transpose() {
        let matrix =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let transposed = matrix.t();

        assert_eq!(transposed.shape().dims(), vec![2, 2]);

        // Original: [[1,2], [3,4]]
        // Transposed: [[1,3], [2,4]]
        assert_eq!(transposed.get(&[0, 0]), 1.0);
        assert_eq!(transposed.get(&[0, 1]), 3.0);
        assert_eq!(transposed.get(&[1, 0]), 2.0);
        assert_eq!(transposed.get(&[1, 1]), 4.0);
    }

    #[test]
    fn test_transpose_3d() {
        let tensor = Tensor::new(vec![2, 3, 4]);
        let transposed = tensor.transpose(0, 2);

        // Shape changes from [2,3,4] to [4,3,2]
        assert_eq!(transposed.shape().dims(), vec![4, 3, 2]);
    }
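
    // Additional value-level check for `transpose(0, 2)`; the expected values follow
    // from the row-major layout produced by `from_slice` and the index mapping
    // transposed[i][j][k] == original[k][j][i].
    #[test]
    fn test_transpose_3d_values() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).expect("Failed to create tensor");
        let transposed = tensor.transpose(0, 2);

        assert_eq!(transposed.shape().dims(), vec![4, 3, 2]);

        // Original element value: original[d0][d1][d2] = d0 * 12 + d1 * 4 + d2
        assert_eq!(transposed.get(&[0, 0, 1]), 12.0); // Maps to original [1,0,0]
        assert_eq!(transposed.get(&[1, 0, 0]), 1.0); // Maps to original [0,0,1]
        assert_eq!(transposed.get(&[3, 2, 1]), 23.0); // Maps to original [1,2,3]
    }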

    #[test]
    fn test_transpose_same_dimension() {
        let tensor =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let result = tensor.transpose(1, 1);

        // Should be identical to original
        assert_eq!(result.shape().dims(), tensor.shape().dims());
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(result.get(&[i, j]), tensor.get(&[i, j]));
            }
        }
    }

    #[test]
    fn test_transpose_preserves_gradient_requirement() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.set_requires_grad(true);
        let transposed = tensor.transpose(0, 1);

        assert!(transposed.requires_grad());
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_transpose_invalid_dimension() {
        let tensor = Tensor::new(vec![2, 3]);
        tensor.transpose(0, 3); // Should panic: dim 3 out of bounds
    }

    #[test]
    #[should_panic(expected = "Matrix transpose requires at least 2 dimensions")]
    fn test_matrix_transpose_1d() {
        let tensor = Tensor::new(vec![5]);
        tensor.t(); // Should panic: 1D tensor
    }

    #[test]
    fn test_transpose_large_tensor() {
        // Test with larger tensor to exercise cache-optimized path
        let tensor = Tensor::new(vec![32, 32]); // 1024 elements
        let transposed = tensor.transpose(0, 1);

        assert_eq!(transposed.shape().dims(), vec![32, 32]);
    }

    #[test]
    fn test_transpose_memory_layout() {
        let tensor = Tensor::new(vec![3, 4]);
        assert!(tensor.is_contiguous());

        let transposed = tensor.transpose(0, 1);
        // After transpose, the result should still be valid but may not be contiguous
        // depending on implementation (view vs copy)
        assert_eq!(transposed.shape().dims(), vec![4, 3]);
    }

    #[test]
    fn test_transpose_first_dimensions_3d() {
        // Test the critical bug fix: transpose dimensions other than last two
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Transpose the first two dimensions (not last two)
        let transposed = tensor.transpose(0, 1);

        // Shape should change from [2,3,4] to [3,2,4]
        assert_eq!(transposed.shape().dims(), vec![3, 2, 4]);

        // Verify data is correctly transposed
        // Original: tensor[d0][d1][d2] where d0=2, d1=3, d2=4
        // After transpose(0,1): tensor[d1][d0][d2] where d1=3, d0=2, d2=4

        assert_eq!(transposed.get(&[0, 0, 0]), 0.0); // Maps to original [0,0,0]
        assert_eq!(transposed.get(&[0, 1, 0]), 12.0); // Maps to original [1,0,0]
        assert_eq!(transposed.get(&[1, 0, 0]), 4.0); // Maps to original [0,1,0]
        assert_eq!(transposed.get(&[1, 1, 0]), 16.0); // Maps to original [1,1,0]
        assert_eq!(transposed.get(&[2, 0, 0]), 8.0); // Maps to original [0,2,0]
        assert_eq!(transposed.get(&[2, 1, 0]), 20.0); // Maps to original [1,2,0]
    }
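
    // Value-level check that `t()` transposes only the last two dimensions, treating
    // the leading dimension as a batch; expected values follow from the row-major
    // layout produced by `from_slice`.
    #[test]
    fn test_matrix_transpose_3d_values() {
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).expect("Failed to create tensor");
        let transposed = tensor.t();

        assert_eq!(transposed.shape().dims(), vec![2, 3, 2]);

        // transposed[b][i][j] == original[b][j][i], where original[b][j][i] = b * 6 + j * 3 + i
        assert_eq!(transposed.get(&[0, 0, 1]), 3.0); // Maps to original [0,1,0]
        assert_eq!(transposed.get(&[0, 2, 1]), 5.0); // Maps to original [0,1,2]
        assert_eq!(transposed.get(&[1, 0, 1]), 9.0); // Maps to original [1,1,0]
        assert_eq!(transposed.get(&[1, 2, 0]), 8.0); // Maps to original [1,0,2]
    }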
}