train_station/tensor/transform/transpose.rs
1//! Tensor transpose operations
2//!
3//! This module provides tensor transpose functionality that swaps dimensions
4//! of tensors, effectively changing the memory access pattern and logical
5//! arrangement of data. Transposition is a fundamental tensor transformation
6//! operation used in machine learning for matrix operations, preparing data
7//! for specific layer types, and implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `transpose()` - Swap two specified dimensions of a tensor
12//! * `t()` - Matrix transpose (swap last two dimensions)
13//!
14//! # Performance Characteristics
15//!
16//! * **Zero-Copy Operations**: Returns a view when possible using stride manipulation
17//! * **Memory Efficient**: Reuses existing tensor data through view operations
18//! * **Cache Optimized**: Uses optimized copying when view operations are not possible
19//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
20//! * **Shape Transformation**: Changes dimension order while preserving total elements
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Basic 2D transpose
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
29//! let transposed = tensor.transpose(0, 1);
30//! assert_eq!(transposed.shape().dims, vec![3, 2]);
31//! ```
32//!
33//! ```
34//! use train_station::Tensor;
35//!
36//! // Matrix transpose convenience method
37//! let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
38//! let transposed = matrix.t();
39//! assert_eq!(transposed.shape().dims, vec![2, 2]);
40//! ```
41//!
42//! # Gradient Tracking
43//!
44//! The transpose operations support automatic gradient tracking through
45//! the GradTrack system. When `requires_grad` is enabled, the operations
46//! register gradient functions that apply the inverse transpose during
47//! backward passes.
48
49use crate::tensor::core::Tensor;
50
51impl Tensor {
52 /// Transpose two dimensions of the tensor
53 ///
54 /// Swaps two specified dimensions of the tensor, modifying the shape
55 /// and memory access pattern. When possible, this operation returns
56 /// a zero-copy view using stride manipulation. For complex cases or
57 /// non-contiguous tensors, data is copied to ensure correct transposition.
58 ///
59 /// The transpose operation is its own inverse - applying transpose
60 /// twice with the same dimensions returns the original tensor.
61 ///
62 /// # Arguments
63 ///
64 /// * `dim0` - First dimension to swap (must be < tensor rank)
65 /// * `dim1` - Second dimension to swap (must be < tensor rank)
66 ///
67 /// # Returns
68 ///
69 /// A new tensor with the specified dimensions transposed. The total
70 /// number of elements remains unchanged.
71 ///
72 /// # Panics
73 ///
74 /// * If `dim0` is out of bounds for the tensor rank
75 /// * If `dim1` is out of bounds for the tensor rank
76 ///
77 /// # Examples
78 ///
79 /// ```
80 /// use train_station::Tensor;
81 ///
82 /// // Basic 2D transpose
83 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
84 /// let transposed = tensor.transpose(0, 1);
85 /// assert_eq!(transposed.shape().dims, vec![3, 2]);
86 /// assert_eq!(transposed.get(&[0, 0]), 1.0);
87 /// assert_eq!(transposed.get(&[0, 1]), 4.0);
88 /// assert_eq!(transposed.get(&[1, 0]), 2.0);
89 /// assert_eq!(transposed.get(&[1, 1]), 5.0);
90 /// assert_eq!(transposed.get(&[2, 0]), 3.0);
91 /// assert_eq!(transposed.get(&[2, 1]), 6.0);
92 /// ```
93 ///
94 /// ```
95 /// use train_station::Tensor;
96 ///
97 /// // 3D tensor transpose
98 /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
99 /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
100 /// let transposed = tensor.transpose(0, 1);
101 /// assert_eq!(transposed.shape().dims, vec![3, 2, 4]);
102 /// ```
103 ///
104 /// ```
105 /// use train_station::Tensor;
106 ///
107 /// // Transpose with gradient tracking
108 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
109 /// tensor.set_requires_grad(true);
110 ///
111 /// let transposed = tensor.transpose(0, 1);
112 /// assert!(transposed.requires_grad());
113 /// assert_eq!(transposed.shape().dims, vec![2, 2]);
114 /// ```
115 ///
116 /// ```
117 /// use train_station::Tensor;
118 ///
119 /// // Transpose same dimension (no change)
120 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
121 /// let result = tensor.transpose(1, 1);
122 /// assert_eq!(result.shape().dims, tensor.shape().dims);
123 /// assert_eq!(result.get(&[0, 0]), tensor.get(&[0, 0]));
124 /// ```
125 ///
126 /// ```
127 /// use train_station::Tensor;
128 ///
129 /// // Transpose is its own inverse
130 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
131 /// let transposed = tensor.transpose(0, 1);
132 /// let double_transposed = transposed.transpose(0, 1);
133 /// assert_eq!(double_transposed.shape().dims, tensor.shape().dims);
134 /// assert_eq!(double_transposed.get(&[0, 0]), tensor.get(&[0, 0]));
135 /// ```
136 ///
137 /// # Performance
138 ///
139 /// - **Contiguous tensors**: O(1) time complexity, returns a view
140 /// - **Non-contiguous tensors**: O(n) time complexity with data copying
141 /// - **Memory usage**: No additional allocation for view operations
142 /// - **Gradient tracking**: Preserves gradient requirements and tracking
143 ///
144 /// # Relationship to Other Operations
145 ///
146 /// This operation is related to other tensor transformations:
147 /// - `t()` - Convenience method for matrix transpose (last two dimensions)
148 /// - `permute()` - More general dimension reordering operation
149 /// - `reshape()` - Changes shape without changing dimension order
150 ///
151 /// # Memory Layout
152 ///
153 /// For contiguous tensors, transpose returns a view with modified strides,
154 /// making the tensor non-contiguous. For non-contiguous tensors or complex
155 /// cases, data is copied to ensure correct transposition.
156 #[track_caller]
157 pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor {
158 assert!(
159 dim0 < self.shape().rank(),
160 "dim0 {} out of bounds for tensor with rank {}",
161 dim0,
162 self.shape().rank()
163 );
164 assert!(
165 dim1 < self.shape().rank(),
166 "dim1 {} out of bounds for tensor with rank {}",
167 dim1,
168 self.shape().rank()
169 );
170
171 // If same dimension, return a clone
172 if dim0 == dim1 {
173 return self.clone();
174 }
175
176 // Create new dimensions and strides by swapping
177 let mut new_dims = self.shape().dims.clone();
178 let mut new_strides = self.strides().to_vec();
179
180 new_dims.swap(dim0, dim1);
181 new_strides.swap(dim0, dim1);
182
183 // Create a view-based transpose when possible (creates non-contiguous tensor)
184 let mut result = if self.is_contiguous() && self.can_transpose_as_view(dim0, dim1) {
185 let new_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
186 self.create_view_with_shape(new_shape)
187 } else {
188 // Fallback to copy for complex cases
189 self.transpose_with_copy(new_dims, new_strides, dim0, dim1)
190 };
191
192 // GradTrack: register transpose for backward (transpose is its own inverse)
193 if self.requires_grad() {
194 result.set_requires_grad(true);
195 let grad_fn = crate::gradtrack::grad_fn::GradFn::Transpose {
196 dim0,
197 dim1,
198 input_shape: self.shape().dims.clone(),
199 };
200 result.set_grad_fn(grad_fn.clone());
201 crate::gradtrack::engine::GradEngine::register_operation(
202 result.id(),
203 vec![self.id()],
204 grad_fn,
205 );
206 }
207
208 result
209 }
210
211 /// Matrix transpose (transpose last two dimensions)
212 ///
213 /// Convenience method for the common case of matrix transposition.
214 /// For 2D tensors, this performs a standard matrix transpose.
215 /// For higher-dimensional tensors, this transposes the last two
216 /// dimensions, treating the tensor as a batch of matrices.
217 ///
218 /// This method is equivalent to `transpose(rank-2, rank-1)` where
219 /// `rank` is the number of dimensions in the tensor.
220 ///
221 /// # Returns
222 ///
223 /// A new tensor with the last two dimensions transposed
224 ///
225 /// # Panics
226 ///
227 /// * If the tensor has less than 2 dimensions
228 ///
229 /// # Examples
230 ///
231 /// ```
232 /// use train_station::Tensor;
233 ///
234 /// // 2D matrix transpose
235 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
236 /// let transposed = matrix.t();
237 /// assert_eq!(transposed.shape().dims, vec![2, 2]);
238 /// assert_eq!(transposed.get(&[0, 0]), 1.0);
239 /// assert_eq!(transposed.get(&[0, 1]), 3.0);
240 /// assert_eq!(transposed.get(&[1, 0]), 2.0);
241 /// assert_eq!(transposed.get(&[1, 1]), 4.0);
242 /// ```
243 ///
244 /// ```
245 /// use train_station::Tensor;
246 ///
247 /// // 3D tensor (batch of matrices)
248 /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
249 /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
250 /// let transposed = tensor.t();
251 /// assert_eq!(transposed.shape().dims, vec![2, 3, 2]);
252 /// ```
253 ///
254 /// ```
255 /// use train_station::Tensor;
256 ///
257 /// // Matrix transpose with gradient tracking
258 /// let mut matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
259 /// matrix.set_requires_grad(true);
260 ///
261 /// let transposed = matrix.t();
262 /// assert!(transposed.requires_grad());
263 /// assert_eq!(transposed.shape().dims, vec![2, 2]);
264 /// ```
265 ///
266 /// # Performance
267 ///
268 /// - **Time Complexity**: Same as `transpose()` - O(1) for views, O(n) for copies
269 /// - **Memory Usage**: Same as `transpose()` - no allocation for views
270 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
271 ///
272 /// # Relationship to Other Operations
273 ///
274 /// This operation is equivalent to:
275 /// ```rust
276 /// use train_station::Tensor;
277 ///
278 /// let tensor = Tensor::new(vec![2, 3, 4]);
279 /// let rank = tensor.shape().rank();
280 /// let transposed1 = tensor.t();
281 /// let transposed2 = tensor.transpose(rank - 2, rank - 1);
282 /// // transposed1 and transposed2 are identical
283 /// ```
284 #[track_caller]
285 pub fn t(&self) -> Tensor {
286 assert!(
287 self.shape().rank() >= 2,
288 "Matrix transpose requires at least 2 dimensions, got {}",
289 self.shape().rank()
290 );
291 let rank = self.shape().rank();
292 self.transpose(rank - 2, rank - 1)
293 }
294
295 /// Check if transpose can be done as a zero-copy view operation
296 ///
297 /// Determines whether the transpose operation can be performed as a
298 /// zero-copy view by manipulating strides rather than copying data.
299 /// This is possible for contiguous tensors when swapping different dimensions.
300 ///
301 /// # Arguments
302 ///
303 /// * `dim0` - First dimension to swap
304 /// * `dim1` - Second dimension to swap
305 ///
306 /// # Returns
307 ///
308 /// `true` if the transpose can be done as a view (zero-copy), `false` otherwise
309 ///
310 /// # Performance
311 ///
312 /// - **Time Complexity**: O(1) - Simple boolean checks
313 /// - **Memory Usage**: No allocation
314 ///
315 /// # Examples
316 ///
317 /// This method is used internally by the `transpose()` function to
318 /// determine the optimal implementation strategy (view vs copy).
319 fn can_transpose_as_view(&self, dim0: usize, dim1: usize) -> bool {
320 // For contiguous tensors, we can always create a view with different strides
321 // This is safe because we're not modifying the underlying data, just the access pattern
322 self.is_contiguous() && (dim0 != dim1)
323 }
324
325 /// Transpose with data copying when view operation is not possible
326 ///
327 /// Performs transpose by copying data to a new tensor when a view-based
328 /// transpose is not possible or optimal. This method ensures correct
329 /// transposition for all tensor types and memory layouts.
330 ///
331 /// # Arguments
332 ///
333 /// * `new_dims` - The new dimensions after transposition
334 /// * `_new_strides` - The new strides after transposition (unused in copy implementation)
335 /// * `dim0` - First dimension that was swapped
336 /// * `dim1` - Second dimension that was swapped
337 ///
338 /// # Returns
339 ///
340 /// A new tensor with copied and transposed data
341 ///
342 /// # Performance
343 ///
344 /// - **Time Complexity**: O(n) where n is the number of elements
345 /// - **Memory Usage**: Allocates new tensor with same total size
346 /// - **Data Integrity**: Ensures correct transposition for all cases
347 ///
348 /// # Examples
349 ///
350 /// This method is called internally by `transpose()` when view-based
351 /// transposition is not possible, such as for non-contiguous tensors
352 /// or complex memory layouts.
353 fn transpose_with_copy(
354 &self,
355 new_dims: Vec<usize>,
356 _new_strides: Vec<usize>,
357 dim0: usize,
358 dim1: usize,
359 ) -> Tensor {
360 let mut result = Tensor::new(new_dims.clone());
361
362 // Use stride-aware copying that correctly handles arbitrary dimension swaps
363 unsafe {
364 self.transpose_copy_stride_aware(&mut result, dim0, dim1);
365 }
366
367 // Preserve gradient tracking requirement
368 if self.requires_grad() {
369 result.set_requires_grad(true);
370 }
371
372 result
373 }
374
375 /// Stride-aware transpose copy that correctly handles arbitrary dimension swaps
376 ///
377 /// Performs efficient transpose copying using coordinate mapping and
378 /// stride calculations. This method correctly handles transposition
379 /// of any two dimensions in tensors of arbitrary rank and shape.
380 ///
381 /// # Arguments
382 ///
383 /// * `result` - Output tensor to write transposed data
384 /// * `dim0` - First dimension that was swapped
385 /// * `dim1` - Second dimension that was swapped
386 ///
387 /// # Safety
388 ///
389 /// This function uses unsafe pointer arithmetic for performance.
390 /// The caller must ensure:
391 /// * `result` tensor has the correct size and shape
392 /// * `result` tensor is properly allocated and accessible
393 /// * `dim0` and `dim1` are valid dimension indices
394 /// * Source tensor data is valid and accessible
395 ///
396 /// # Performance
397 ///
398 /// - **Time Complexity**: O(n) where n is the number of elements
399 /// - **Memory Access**: Optimized for cache-friendly access patterns
400 /// - **Coordinate Mapping**: Efficient conversion between linear and multi-dimensional indices
401 /// - **Bounds Checking**: Debug assertions for safety in debug builds
402 ///
403 /// # Examples
404 ///
405 /// This method is used internally by `transpose_with_copy()` to perform
406 /// the actual data copying with correct coordinate mapping for arbitrary
407 /// dimension swaps.
408 unsafe fn transpose_copy_stride_aware(&self, result: &mut Tensor, dim0: usize, dim1: usize) {
409 let src_ptr = self.as_ptr();
410 let dst_ptr = result.as_mut_ptr();
411
412 // Iterate through all elements of the result tensor
413 for dst_idx in 0..result.size() {
414 // Convert linear index to multi-dimensional coordinates for result
415 let mut dst_coords = Vec::new();
416 let mut temp_idx = dst_idx;
417
418 for &dim_size in result.shape().dims.iter().rev() {
419 dst_coords.push(temp_idx % dim_size);
420 temp_idx /= dim_size;
421 }
422 dst_coords.reverse();
423
424 // Map result coordinates to source coordinates (reverse the transpose)
425 let mut src_coords = dst_coords.clone();
426 src_coords.swap(dim0, dim1);
427
428 // Calculate source offset using strides
429 let src_offset = self.shape().offset(&src_coords);
430
431 // Bounds check to prevent buffer overruns
432 debug_assert!(
433 src_offset < self.size(),
434 "Source offset {} out of bounds for tensor size {}",
435 src_offset,
436 self.size()
437 );
438 debug_assert!(
439 dst_idx < result.size(),
440 "Destination index {} out of bounds for result size {}",
441 dst_idx,
442 result.size()
443 );
444
445 // Copy element
446 *dst_ptr.add(dst_idx) = *src_ptr.add(src_offset);
447 }
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert a list of (coordinates, expected value) pairs against a tensor.
    fn assert_values(t: &Tensor, cases: &[(&[usize], f32)]) {
        for &(coords, expected) in cases {
            assert_eq!(t.get(coords), expected);
        }
    }

    #[test]
    fn test_transpose_2d_basic() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let original = Tensor::from_slice(&data, vec![2, 3]).expect("Failed to create tensor");
        let swapped = original.transpose(0, 1);

        assert_eq!(swapped.shape().dims, vec![3, 2]);

        // Original [[1,2,3],[4,5,6]] becomes [[1,4],[2,5],[3,6]].
        assert_values(
            &swapped,
            &[
                (&[0, 0], 1.0),
                (&[0, 1], 4.0),
                (&[1, 0], 2.0),
                (&[1, 1], 5.0),
                (&[2, 0], 3.0),
                (&[2, 1], 6.0),
            ],
        );
    }

    #[test]
    fn test_matrix_transpose() {
        let matrix =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let swapped = matrix.t();

        assert_eq!(swapped.shape().dims, vec![2, 2]);

        // [[1,2],[3,4]] becomes [[1,3],[2,4]].
        assert_values(
            &swapped,
            &[
                (&[0, 0], 1.0),
                (&[0, 1], 3.0),
                (&[1, 0], 2.0),
                (&[1, 1], 4.0),
            ],
        );
    }

    #[test]
    fn test_transpose_3d() {
        // Swapping the outermost and innermost dimensions of a [2,3,4]
        // tensor yields shape [4,3,2].
        let swapped = Tensor::new(vec![2, 3, 4]).transpose(0, 2);
        assert_eq!(swapped.shape().dims, vec![4, 3, 2]);
    }

    #[test]
    fn test_transpose_same_dimension() {
        let source =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let identity = source.transpose(1, 1);

        // Transposing a dimension with itself must be a no-op.
        assert_eq!(identity.shape().dims, source.shape().dims);
        for idx in [[0usize, 0], [0, 1], [1, 0], [1, 1]] {
            assert_eq!(identity.get(&idx), source.get(&idx));
        }
    }

    #[test]
    fn test_transpose_preserves_gradient_requirement() {
        let mut tracked = Tensor::new(vec![2, 3]);
        tracked.set_requires_grad(true);

        assert!(tracked.transpose(0, 1).requires_grad());
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_transpose_invalid_dimension() {
        // Rank-2 tensor: dimension index 3 must trigger the bounds panic.
        Tensor::new(vec![2, 3]).transpose(0, 3);
    }

    #[test]
    #[should_panic(expected = "Matrix transpose requires at least 2 dimensions")]
    fn test_matrix_transpose_1d() {
        // A 1D tensor has no matrix transpose.
        Tensor::new(vec![5]).t();
    }

    #[test]
    fn test_transpose_large_tensor() {
        // Larger tensor (1024 elements) to exercise the optimized path.
        let swapped = Tensor::new(vec![32, 32]).transpose(0, 1);
        assert_eq!(swapped.shape().dims, vec![32, 32]);
    }

    #[test]
    fn test_transpose_memory_layout() {
        let source = Tensor::new(vec![3, 4]);
        assert!(source.is_contiguous());

        // The transposed result must be valid regardless of whether the
        // implementation chose a view or a copy; it may be non-contiguous.
        let swapped = source.transpose(0, 1);
        assert_eq!(swapped.shape().dims, vec![4, 3]);
    }

    #[test]
    fn test_transpose_first_dimensions_3d() {
        // Regression coverage: transposing dimensions other than the last
        // two must map data correctly.
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let source = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        let swapped = source.transpose(0, 1);
        assert_eq!(swapped.shape().dims, vec![3, 2, 4]);

        // swapped[d1][d0][d2] == source[d0][d1][d2] with strides (12, 4, 1).
        assert_values(
            &swapped,
            &[
                (&[0, 0, 0], 0.0),  // source [0,0,0]
                (&[0, 1, 0], 12.0), // source [1,0,0]
                (&[1, 0, 0], 4.0),  // source [0,1,0]
                (&[1, 1, 0], 16.0), // source [1,1,0]
                (&[2, 0, 0], 8.0),  // source [0,2,0]
                (&[2, 1, 0], 20.0), // source [1,2,0]
            ],
        );
    }
}
580}