train_station/tensor/transform/transpose.rs
1//! Tensor transpose operations
2//!
3//! This module provides tensor transpose functionality that swaps dimensions
4//! of tensors, effectively changing the memory access pattern and logical
5//! arrangement of data. Transposition is a fundamental tensor transformation
6//! operation used in machine learning for matrix operations, preparing data
7//! for specific layer types, and implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `transpose()` - Swap two specified dimensions of a tensor
12//! * `t()` - Matrix transpose (swap last two dimensions)
13//!
14//! # Performance Characteristics
15//!
16//! * **Zero-Copy Operations**: Returns a view when possible using stride manipulation
17//! * **Memory Efficient**: Reuses existing tensor data through view operations
18//! * **Cache Optimized**: Uses optimized copying when view operations are not possible
19//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
20//! * **Shape Transformation**: Changes dimension order while preserving total elements
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Basic 2D transpose
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
29//! let transposed = tensor.transpose(0, 1);
30//! assert_eq!(transposed.shape().dims, vec![3, 2]);
31//! ```
32//!
33//! ```
34//! use train_station::Tensor;
35//!
36//! // Matrix transpose convenience method
37//! let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
38//! let transposed = matrix.t();
39//! assert_eq!(transposed.shape().dims, vec![2, 2]);
40//! ```
41//!
42//! # Gradient Tracking
43//!
44//! The transpose operations support automatic gradient tracking through
45//! the GradTrack system. When `requires_grad` is enabled, the operations
46//! register gradient functions that apply the inverse transpose during
47//! backward passes.
48
49use crate::tensor::core::Tensor;
50
51impl Tensor {
52 /// Transpose two dimensions of the tensor
53 ///
54 /// Swaps two specified dimensions of the tensor, modifying the shape
55 /// and memory access pattern. When possible, this operation returns
56 /// a zero-copy view using stride manipulation. For complex cases or
57 /// non-contiguous tensors, data is copied to ensure correct transposition.
58 ///
59 /// The transpose operation is its own inverse - applying transpose
60 /// twice with the same dimensions returns the original tensor.
61 ///
62 /// # Arguments
63 ///
64 /// * `dim0` - First dimension to swap (must be < tensor rank)
65 /// * `dim1` - Second dimension to swap (must be < tensor rank)
66 ///
67 /// # Returns
68 ///
69 /// A new tensor with the specified dimensions transposed. The total
70 /// number of elements remains unchanged.
71 ///
72 /// # Panics
73 ///
74 /// * If `dim0` is out of bounds for the tensor rank
75 /// * If `dim1` is out of bounds for the tensor rank
76 ///
77 /// # Examples
78 ///
79 /// ```
80 /// use train_station::Tensor;
81 ///
82 /// // Basic 2D transpose
83 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
84 /// let transposed = tensor.transpose(0, 1);
85 /// assert_eq!(transposed.shape().dims, vec![3, 2]);
86 /// assert_eq!(transposed.get(&[0, 0]), 1.0);
87 /// assert_eq!(transposed.get(&[0, 1]), 4.0);
88 /// assert_eq!(transposed.get(&[1, 0]), 2.0);
89 /// assert_eq!(transposed.get(&[1, 1]), 5.0);
90 /// assert_eq!(transposed.get(&[2, 0]), 3.0);
91 /// assert_eq!(transposed.get(&[2, 1]), 6.0);
92 /// ```
93 ///
94 /// ```
95 /// use train_station::Tensor;
96 ///
97 /// // 3D tensor transpose
98 /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
99 /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
100 /// let transposed = tensor.transpose(0, 1);
101 /// assert_eq!(transposed.shape().dims, vec![3, 2, 4]);
102 /// ```
103 ///
104 /// ```
105 /// use train_station::Tensor;
106 ///
107 /// // Transpose with gradient tracking
108 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
109 /// tensor.set_requires_grad(true);
110 ///
111 /// let transposed = tensor.transpose(0, 1);
112 /// assert!(transposed.requires_grad());
113 /// assert_eq!(transposed.shape().dims, vec![2, 2]);
114 /// ```
115 ///
116 /// ```
117 /// use train_station::Tensor;
118 ///
119 /// // Transpose same dimension (no change)
120 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
121 /// let result = tensor.transpose(1, 1);
122 /// assert_eq!(result.shape().dims, tensor.shape().dims);
123 /// assert_eq!(result.get(&[0, 0]), tensor.get(&[0, 0]));
124 /// ```
125 ///
126 /// ```
127 /// use train_station::Tensor;
128 ///
129 /// // Transpose is its own inverse
130 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
131 /// let transposed = tensor.transpose(0, 1);
132 /// let double_transposed = transposed.transpose(0, 1);
133 /// assert_eq!(double_transposed.shape().dims, tensor.shape().dims);
134 /// assert_eq!(double_transposed.get(&[0, 0]), tensor.get(&[0, 0]));
135 /// ```
136 ///
137 /// # Performance
138 ///
139 /// - **Contiguous tensors**: O(1) time complexity, returns a view
140 /// - **Non-contiguous tensors**: O(n) time complexity with data copying
141 /// - **Memory usage**: No additional allocation for view operations
142 /// - **Gradient tracking**: Preserves gradient requirements and tracking
143 ///
144 /// # Relationship to Other Operations
145 ///
146 /// This operation is related to other tensor transformations:
147 /// - `t()` - Convenience method for matrix transpose (last two dimensions)
148 /// - `permute()` - More general dimension reordering operation
149 /// - `reshape()` - Changes shape without changing dimension order
150 ///
151 /// # Memory Layout
152 ///
153 /// For contiguous tensors, transpose returns a view with modified strides,
154 /// making the tensor non-contiguous. For non-contiguous tensors or complex
155 /// cases, data is copied to ensure correct transposition.
156 pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor {
157 assert!(
158 dim0 < self.shape().rank(),
159 "dim0 {} out of bounds for tensor with rank {}",
160 dim0,
161 self.shape().rank()
162 );
163 assert!(
164 dim1 < self.shape().rank(),
165 "dim1 {} out of bounds for tensor with rank {}",
166 dim1,
167 self.shape().rank()
168 );
169
170 // If same dimension, return a clone
171 if dim0 == dim1 {
172 return self.clone();
173 }
174
175 // Create new dimensions and strides by swapping
176 let mut new_dims = self.shape().dims.clone();
177 let mut new_strides = self.strides().to_vec();
178
179 new_dims.swap(dim0, dim1);
180 new_strides.swap(dim0, dim1);
181
182 // Create a view-based transpose when possible (creates non-contiguous tensor)
183 let mut result = if self.is_contiguous() && self.can_transpose_as_view(dim0, dim1) {
184 let new_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
185 self.create_view_with_shape(new_shape)
186 } else {
187 // Fallback to copy for complex cases
188 self.transpose_with_copy(new_dims, new_strides, dim0, dim1)
189 };
190
191 // GradTrack: register transpose for backward (transpose is its own inverse)
192 if self.requires_grad() {
193 result.set_requires_grad(true);
194 let grad_fn = crate::gradtrack::grad_fn::GradFn::Transpose {
195 dim0,
196 dim1,
197 input_shape: self.shape().dims.clone(),
198 };
199 result.set_grad_fn(grad_fn.clone());
200 crate::gradtrack::engine::GradEngine::register_operation(
201 result.id(),
202 vec![self.id()],
203 grad_fn,
204 );
205 }
206
207 result
208 }
209
210 /// Matrix transpose (transpose last two dimensions)
211 ///
212 /// Convenience method for the common case of matrix transposition.
213 /// For 2D tensors, this performs a standard matrix transpose.
214 /// For higher-dimensional tensors, this transposes the last two
215 /// dimensions, treating the tensor as a batch of matrices.
216 ///
217 /// This method is equivalent to `transpose(rank-2, rank-1)` where
218 /// `rank` is the number of dimensions in the tensor.
219 ///
220 /// # Returns
221 ///
222 /// A new tensor with the last two dimensions transposed
223 ///
224 /// # Panics
225 ///
226 /// * If the tensor has less than 2 dimensions
227 ///
228 /// # Examples
229 ///
230 /// ```
231 /// use train_station::Tensor;
232 ///
233 /// // 2D matrix transpose
234 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
235 /// let transposed = matrix.t();
236 /// assert_eq!(transposed.shape().dims, vec![2, 2]);
237 /// assert_eq!(transposed.get(&[0, 0]), 1.0);
238 /// assert_eq!(transposed.get(&[0, 1]), 3.0);
239 /// assert_eq!(transposed.get(&[1, 0]), 2.0);
240 /// assert_eq!(transposed.get(&[1, 1]), 4.0);
241 /// ```
242 ///
243 /// ```
244 /// use train_station::Tensor;
245 ///
246 /// // 3D tensor (batch of matrices)
247 /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
248 /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
249 /// let transposed = tensor.t();
250 /// assert_eq!(transposed.shape().dims, vec![2, 3, 2]);
251 /// ```
252 ///
253 /// ```
254 /// use train_station::Tensor;
255 ///
256 /// // Matrix transpose with gradient tracking
257 /// let mut matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
258 /// matrix.set_requires_grad(true);
259 ///
260 /// let transposed = matrix.t();
261 /// assert!(transposed.requires_grad());
262 /// assert_eq!(transposed.shape().dims, vec![2, 2]);
263 /// ```
264 ///
265 /// # Performance
266 ///
267 /// - **Time Complexity**: Same as `transpose()` - O(1) for views, O(n) for copies
268 /// - **Memory Usage**: Same as `transpose()` - no allocation for views
269 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
270 ///
271 /// # Relationship to Other Operations
272 ///
273 /// This operation is equivalent to:
274 /// ```rust
275 /// use train_station::Tensor;
276 ///
277 /// let tensor = Tensor::new(vec![2, 3, 4]);
278 /// let rank = tensor.shape().rank();
279 /// let transposed1 = tensor.t();
280 /// let transposed2 = tensor.transpose(rank - 2, rank - 1);
281 /// // transposed1 and transposed2 are identical
282 /// ```
283 pub fn t(&self) -> Tensor {
284 assert!(
285 self.shape().rank() >= 2,
286 "Matrix transpose requires at least 2 dimensions, got {}",
287 self.shape().rank()
288 );
289 let rank = self.shape().rank();
290 self.transpose(rank - 2, rank - 1)
291 }
292
293 /// Check if transpose can be done as a zero-copy view operation
294 ///
295 /// Determines whether the transpose operation can be performed as a
296 /// zero-copy view by manipulating strides rather than copying data.
297 /// This is possible for contiguous tensors when swapping different dimensions.
298 ///
299 /// # Arguments
300 ///
301 /// * `dim0` - First dimension to swap
302 /// * `dim1` - Second dimension to swap
303 ///
304 /// # Returns
305 ///
306 /// `true` if the transpose can be done as a view (zero-copy), `false` otherwise
307 ///
308 /// # Performance
309 ///
310 /// - **Time Complexity**: O(1) - Simple boolean checks
311 /// - **Memory Usage**: No allocation
312 ///
313 /// # Examples
314 ///
315 /// This method is used internally by the `transpose()` function to
316 /// determine the optimal implementation strategy (view vs copy).
317 fn can_transpose_as_view(&self, dim0: usize, dim1: usize) -> bool {
318 // For contiguous tensors, we can always create a view with different strides
319 // This is safe because we're not modifying the underlying data, just the access pattern
320 self.is_contiguous() && (dim0 != dim1)
321 }
322
323 /// Transpose with data copying when view operation is not possible
324 ///
325 /// Performs transpose by copying data to a new tensor when a view-based
326 /// transpose is not possible or optimal. This method ensures correct
327 /// transposition for all tensor types and memory layouts.
328 ///
329 /// # Arguments
330 ///
331 /// * `new_dims` - The new dimensions after transposition
332 /// * `_new_strides` - The new strides after transposition (unused in copy implementation)
333 /// * `dim0` - First dimension that was swapped
334 /// * `dim1` - Second dimension that was swapped
335 ///
336 /// # Returns
337 ///
338 /// A new tensor with copied and transposed data
339 ///
340 /// # Performance
341 ///
342 /// - **Time Complexity**: O(n) where n is the number of elements
343 /// - **Memory Usage**: Allocates new tensor with same total size
344 /// - **Data Integrity**: Ensures correct transposition for all cases
345 ///
346 /// # Examples
347 ///
348 /// This method is called internally by `transpose()` when view-based
349 /// transposition is not possible, such as for non-contiguous tensors
350 /// or complex memory layouts.
351 fn transpose_with_copy(
352 &self,
353 new_dims: Vec<usize>,
354 _new_strides: Vec<usize>,
355 dim0: usize,
356 dim1: usize,
357 ) -> Tensor {
358 let mut result = Tensor::new(new_dims.clone());
359
360 // Use stride-aware copying that correctly handles arbitrary dimension swaps
361 unsafe {
362 self.transpose_copy_stride_aware(&mut result, dim0, dim1);
363 }
364
365 // Preserve gradient tracking requirement
366 if self.requires_grad() {
367 result.set_requires_grad(true);
368 }
369
370 result
371 }
372
373 /// Stride-aware transpose copy that correctly handles arbitrary dimension swaps
374 ///
375 /// Performs efficient transpose copying using coordinate mapping and
376 /// stride calculations. This method correctly handles transposition
377 /// of any two dimensions in tensors of arbitrary rank and shape.
378 ///
379 /// # Arguments
380 ///
381 /// * `result` - Output tensor to write transposed data
382 /// * `dim0` - First dimension that was swapped
383 /// * `dim1` - Second dimension that was swapped
384 ///
385 /// # Safety
386 ///
387 /// This function uses unsafe pointer arithmetic for performance.
388 /// The caller must ensure:
389 /// * `result` tensor has the correct size and shape
390 /// * `result` tensor is properly allocated and accessible
391 /// * `dim0` and `dim1` are valid dimension indices
392 /// * Source tensor data is valid and accessible
393 ///
394 /// # Performance
395 ///
396 /// - **Time Complexity**: O(n) where n is the number of elements
397 /// - **Memory Access**: Optimized for cache-friendly access patterns
398 /// - **Coordinate Mapping**: Efficient conversion between linear and multi-dimensional indices
399 /// - **Bounds Checking**: Debug assertions for safety in debug builds
400 ///
401 /// # Examples
402 ///
403 /// This method is used internally by `transpose_with_copy()` to perform
404 /// the actual data copying with correct coordinate mapping for arbitrary
405 /// dimension swaps.
406 unsafe fn transpose_copy_stride_aware(&self, result: &mut Tensor, dim0: usize, dim1: usize) {
407 let src_ptr = self.as_ptr();
408 let dst_ptr = result.as_mut_ptr();
409
410 // Iterate through all elements of the result tensor
411 for dst_idx in 0..result.size() {
412 // Convert linear index to multi-dimensional coordinates for result
413 let mut dst_coords = Vec::new();
414 let mut temp_idx = dst_idx;
415
416 for &dim_size in result.shape().dims.iter().rev() {
417 dst_coords.push(temp_idx % dim_size);
418 temp_idx /= dim_size;
419 }
420 dst_coords.reverse();
421
422 // Map result coordinates to source coordinates (reverse the transpose)
423 let mut src_coords = dst_coords.clone();
424 src_coords.swap(dim0, dim1);
425
426 // Calculate source offset using strides
427 let src_offset = self.shape().offset(&src_coords);
428
429 // Bounds check to prevent buffer overruns
430 debug_assert!(
431 src_offset < self.size(),
432 "Source offset {} out of bounds for tensor size {}",
433 src_offset,
434 self.size()
435 );
436 debug_assert!(
437 dst_idx < result.size(),
438 "Destination index {} out of bounds for result size {}",
439 dst_idx,
440 result.size()
441 );
442
443 // Copy element
444 *dst_ptr.add(dst_idx) = *src_ptr.add(src_offset);
445 }
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transpose_2d_basic() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .expect("Failed to create tensor");
        let swapped = source.transpose(0, 1);

        assert_eq!(swapped.shape().dims, vec![3, 2]);

        // Data layout check: original [2,3] -> transposed [3,2].
        // Original: [[1,2,3], [4,5,6]]
        // Transposed: [[1,4], [2,5], [3,6]]
        let expectations: [([usize; 2], f32); 6] = [
            ([0, 0], 1.0),
            ([0, 1], 4.0),
            ([1, 0], 2.0),
            ([1, 1], 5.0),
            ([2, 0], 3.0),
            ([2, 1], 6.0),
        ];
        for (coords, want) in expectations {
            assert_eq!(swapped.get(&coords), want);
        }
    }

    #[test]
    fn test_matrix_transpose() {
        let square = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .expect("Failed to create tensor");
        let flipped = square.t();

        assert_eq!(flipped.shape().dims, vec![2, 2]);

        // Original: [[1,2], [3,4]] -> Transposed: [[1,3], [2,4]]
        let expectations: [([usize; 2], f32); 4] = [
            ([0, 0], 1.0),
            ([0, 1], 3.0),
            ([1, 0], 2.0),
            ([1, 1], 4.0),
        ];
        for (coords, want) in expectations {
            assert_eq!(flipped.get(&coords), want);
        }
    }

    #[test]
    fn test_transpose_3d() {
        let cube = Tensor::new(vec![2, 3, 4]);
        let swapped = cube.transpose(0, 2);

        // Shape changes from [2,3,4] to [4,3,2]
        assert_eq!(swapped.shape().dims, vec![4, 3, 2]);
    }

    #[test]
    fn test_transpose_same_dimension() {
        let original = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .expect("Failed to create tensor");
        let unchanged = original.transpose(1, 1);

        // Swapping a dimension with itself must leave the tensor identical.
        assert_eq!(unchanged.shape().dims, original.shape().dims);
        for (row, col) in (0..2).flat_map(|r| (0..2).map(move |c| (r, c))) {
            assert_eq!(unchanged.get(&[row, col]), original.get(&[row, col]));
        }
    }

    #[test]
    fn test_transpose_preserves_gradient_requirement() {
        let mut tracked = Tensor::new(vec![2, 3]);
        tracked.set_requires_grad(true);

        assert!(tracked.transpose(0, 1).requires_grad());
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_transpose_invalid_dimension() {
        let small = Tensor::new(vec![2, 3]);
        small.transpose(0, 3); // Should panic: dim 3 out of bounds
    }

    #[test]
    #[should_panic(expected = "Matrix transpose requires at least 2 dimensions")]
    fn test_matrix_transpose_1d() {
        let vector = Tensor::new(vec![5]);
        vector.t(); // Should panic: 1D tensor
    }

    #[test]
    fn test_transpose_large_tensor() {
        // Larger tensor exercises the cache-optimized path (1024 elements).
        let big = Tensor::new(vec![32, 32]);
        assert_eq!(big.transpose(0, 1).shape().dims, vec![32, 32]);
    }

    #[test]
    fn test_transpose_memory_layout() {
        let grid = Tensor::new(vec![3, 4]);
        assert!(grid.is_contiguous());

        // After transpose, the result should still be valid but may not be
        // contiguous depending on implementation (view vs copy).
        let flipped = grid.transpose(0, 1);
        assert_eq!(flipped.shape().dims, vec![4, 3]);
    }

    #[test]
    fn test_transpose_first_dimensions_3d() {
        // Regression test for the critical bug fix: transposing dimensions
        // other than the last two.
        let payload: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let block = Tensor::from_slice(&payload, vec![2, 3, 4]).unwrap();

        // Transpose the first two dimensions (not last two):
        // shape changes from [2,3,4] to [3,2,4].
        let swapped = block.transpose(0, 1);
        assert_eq!(swapped.shape().dims, vec![3, 2, 4]);

        // Original: tensor[d0][d1][d2] with d0=2, d1=3, d2=4.
        // After transpose(0,1): tensor[d1][d0][d2] with d1=3, d0=2, d2=4.
        // Each entry maps swapped[a,b,0] back to original [b,a,0].
        let expectations: [([usize; 3], f32); 6] = [
            ([0, 0, 0], 0.0),  // original [0,0,0]
            ([0, 1, 0], 12.0), // original [1,0,0]
            ([1, 0, 0], 4.0),  // original [0,1,0]
            ([1, 1, 0], 16.0), // original [1,1,0]
            ([2, 0, 0], 8.0),  // original [0,2,0]
            ([2, 1, 0], 20.0), // original [1,2,0]
        ];
        for (coords, want) in expectations {
            assert_eq!(swapped.get(&coords), want);
        }
    }
}
578}