// train_station/tensor/transform/transpose.rs

//! Tensor transpose operations
//!
//! This module provides tensor transpose functionality that swaps dimensions
//! of tensors, effectively changing the memory access pattern and logical
//! arrangement of data. Transposition is a fundamental tensor transformation
//! operation used in machine learning for matrix operations, preparing data
//! for specific layer types, and implementing complex tensor manipulations.
//!
//! # Operations
//!
//! * `transpose()` - Swap two specified dimensions of a tensor
//! * `t()` - Matrix transpose (swap last two dimensions)
//!
//! # Performance Characteristics
//!
//! * **Zero-Copy Operations**: Returns a view when possible using stride manipulation
//! * **Memory Efficient**: Reuses existing tensor data through view operations
//! * **Cache Optimized**: Uses optimized copying when view operations are not possible
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//! * **Shape Transformation**: Changes dimension order while preserving total elements
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Basic 2D transpose
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
//! let transposed = tensor.transpose(0, 1);
//! assert_eq!(transposed.shape().dims, vec![3, 2]);
//! ```
//!
//! ```
//! use train_station::Tensor;
//!
//! // Matrix transpose convenience method
//! let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let transposed = matrix.t();
//! assert_eq!(transposed.shape().dims, vec![2, 2]);
//! ```
//!
//! # Gradient Tracking
//!
//! The transpose operations support automatic gradient tracking through
//! the GradTrack system. When `requires_grad` is enabled, the operations
//! register gradient functions that apply the inverse transpose during
//! backward passes.
use crate::tensor::core::Tensor;
50
51impl Tensor {
52    /// Transpose two dimensions of the tensor
53    ///
54    /// Swaps two specified dimensions of the tensor, modifying the shape
55    /// and memory access pattern. When possible, this operation returns
56    /// a zero-copy view using stride manipulation. For complex cases or
57    /// non-contiguous tensors, data is copied to ensure correct transposition.
58    ///
59    /// The transpose operation is its own inverse - applying transpose
60    /// twice with the same dimensions returns the original tensor.
61    ///
62    /// # Arguments
63    ///
64    /// * `dim0` - First dimension to swap (must be < tensor rank)
65    /// * `dim1` - Second dimension to swap (must be < tensor rank)
66    ///
67    /// # Returns
68    ///
69    /// A new tensor with the specified dimensions transposed. The total
70    /// number of elements remains unchanged.
71    ///
72    /// # Panics
73    ///
74    /// * If `dim0` is out of bounds for the tensor rank
75    /// * If `dim1` is out of bounds for the tensor rank
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use train_station::Tensor;
81    ///
82    /// // Basic 2D transpose
83    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
84    /// let transposed = tensor.transpose(0, 1);
85    /// assert_eq!(transposed.shape().dims, vec![3, 2]);
86    /// assert_eq!(transposed.get(&[0, 0]), 1.0);
87    /// assert_eq!(transposed.get(&[0, 1]), 4.0);
88    /// assert_eq!(transposed.get(&[1, 0]), 2.0);
89    /// assert_eq!(transposed.get(&[1, 1]), 5.0);
90    /// assert_eq!(transposed.get(&[2, 0]), 3.0);
91    /// assert_eq!(transposed.get(&[2, 1]), 6.0);
92    /// ```
93    ///
94    /// ```
95    /// use train_station::Tensor;
96    ///
97    /// // 3D tensor transpose
98    /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
99    /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
100    /// let transposed = tensor.transpose(0, 1);
101    /// assert_eq!(transposed.shape().dims, vec![3, 2, 4]);
102    /// ```
103    ///
104    /// ```
105    /// use train_station::Tensor;
106    ///
107    /// // Transpose with gradient tracking
108    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
109    /// tensor.set_requires_grad(true);
110    ///
111    /// let transposed = tensor.transpose(0, 1);
112    /// assert!(transposed.requires_grad());
113    /// assert_eq!(transposed.shape().dims, vec![2, 2]);
114    /// ```
115    ///
116    /// ```
117    /// use train_station::Tensor;
118    ///
119    /// // Transpose same dimension (no change)
120    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
121    /// let result = tensor.transpose(1, 1);
122    /// assert_eq!(result.shape().dims, tensor.shape().dims);
123    /// assert_eq!(result.get(&[0, 0]), tensor.get(&[0, 0]));
124    /// ```
125    ///
126    /// ```
127    /// use train_station::Tensor;
128    ///
129    /// // Transpose is its own inverse
130    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
131    /// let transposed = tensor.transpose(0, 1);
132    /// let double_transposed = transposed.transpose(0, 1);
133    /// assert_eq!(double_transposed.shape().dims, tensor.shape().dims);
134    /// assert_eq!(double_transposed.get(&[0, 0]), tensor.get(&[0, 0]));
135    /// ```
136    ///
137    /// # Performance
138    ///
139    /// - **Contiguous tensors**: O(1) time complexity, returns a view
140    /// - **Non-contiguous tensors**: O(n) time complexity with data copying
141    /// - **Memory usage**: No additional allocation for view operations
142    /// - **Gradient tracking**: Preserves gradient requirements and tracking
143    ///
144    /// # Relationship to Other Operations
145    ///
146    /// This operation is related to other tensor transformations:
147    /// - `t()` - Convenience method for matrix transpose (last two dimensions)
148    /// - `permute()` - More general dimension reordering operation
149    /// - `reshape()` - Changes shape without changing dimension order
150    ///
151    /// # Memory Layout
152    ///
153    /// For contiguous tensors, transpose returns a view with modified strides,
154    /// making the tensor non-contiguous. For non-contiguous tensors or complex
155    /// cases, data is copied to ensure correct transposition.
156    #[track_caller]
157    pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor {
158        assert!(
159            dim0 < self.shape().rank(),
160            "dim0 {} out of bounds for tensor with rank {}",
161            dim0,
162            self.shape().rank()
163        );
164        assert!(
165            dim1 < self.shape().rank(),
166            "dim1 {} out of bounds for tensor with rank {}",
167            dim1,
168            self.shape().rank()
169        );
170
171        // If same dimension, return a clone
172        if dim0 == dim1 {
173            return self.clone();
174        }
175
176        // Create new dimensions and strides by swapping
177        let mut new_dims = self.shape().dims.clone();
178        let mut new_strides = self.strides().to_vec();
179
180        new_dims.swap(dim0, dim1);
181        new_strides.swap(dim0, dim1);
182
183        // Create a view-based transpose when possible (creates non-contiguous tensor)
184        let mut result = if self.is_contiguous() && self.can_transpose_as_view(dim0, dim1) {
185            let new_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
186            self.create_view_with_shape(new_shape)
187        } else {
188            // Fallback to copy for complex cases
189            self.transpose_with_copy(new_dims, new_strides, dim0, dim1)
190        };
191
192        // GradTrack: register transpose for backward (transpose is its own inverse)
193        if self.requires_grad() {
194            result.set_requires_grad(true);
195            let grad_fn = crate::gradtrack::grad_fn::GradFn::Transpose {
196                dim0,
197                dim1,
198                input_shape: self.shape().dims.clone(),
199            };
200            result.set_grad_fn(grad_fn.clone());
201            crate::gradtrack::engine::GradEngine::register_operation(
202                result.id(),
203                vec![self.id()],
204                grad_fn,
205            );
206        }
207
208        result
209    }
210
211    /// Matrix transpose (transpose last two dimensions)
212    ///
213    /// Convenience method for the common case of matrix transposition.
214    /// For 2D tensors, this performs a standard matrix transpose.
215    /// For higher-dimensional tensors, this transposes the last two
216    /// dimensions, treating the tensor as a batch of matrices.
217    ///
218    /// This method is equivalent to `transpose(rank-2, rank-1)` where
219    /// `rank` is the number of dimensions in the tensor.
220    ///
221    /// # Returns
222    ///
223    /// A new tensor with the last two dimensions transposed
224    ///
225    /// # Panics
226    ///
227    /// * If the tensor has less than 2 dimensions
228    ///
229    /// # Examples
230    ///
231    /// ```
232    /// use train_station::Tensor;
233    ///
234    /// // 2D matrix transpose
235    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
236    /// let transposed = matrix.t();
237    /// assert_eq!(transposed.shape().dims, vec![2, 2]);
238    /// assert_eq!(transposed.get(&[0, 0]), 1.0);
239    /// assert_eq!(transposed.get(&[0, 1]), 3.0);
240    /// assert_eq!(transposed.get(&[1, 0]), 2.0);
241    /// assert_eq!(transposed.get(&[1, 1]), 4.0);
242    /// ```
243    ///
244    /// ```
245    /// use train_station::Tensor;
246    ///
247    /// // 3D tensor (batch of matrices)
248    /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
249    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
250    /// let transposed = tensor.t();
251    /// assert_eq!(transposed.shape().dims, vec![2, 3, 2]);
252    /// ```
253    ///
254    /// ```
255    /// use train_station::Tensor;
256    ///
257    /// // Matrix transpose with gradient tracking
258    /// let mut matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
259    /// matrix.set_requires_grad(true);
260    ///
261    /// let transposed = matrix.t();
262    /// assert!(transposed.requires_grad());
263    /// assert_eq!(transposed.shape().dims, vec![2, 2]);
264    /// ```
265    ///
266    /// # Performance
267    ///
268    /// - **Time Complexity**: Same as `transpose()` - O(1) for views, O(n) for copies
269    /// - **Memory Usage**: Same as `transpose()` - no allocation for views
270    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
271    ///
272    /// # Relationship to Other Operations
273    ///
274    /// This operation is equivalent to:
275    /// ```rust
276    /// use train_station::Tensor;
277    ///
278    /// let tensor = Tensor::new(vec![2, 3, 4]);
279    /// let rank = tensor.shape().rank();
280    /// let transposed1 = tensor.t();
281    /// let transposed2 = tensor.transpose(rank - 2, rank - 1);
282    /// // transposed1 and transposed2 are identical
283    /// ```
284    #[track_caller]
285    pub fn t(&self) -> Tensor {
286        assert!(
287            self.shape().rank() >= 2,
288            "Matrix transpose requires at least 2 dimensions, got {}",
289            self.shape().rank()
290        );
291        let rank = self.shape().rank();
292        self.transpose(rank - 2, rank - 1)
293    }
294
295    /// Check if transpose can be done as a zero-copy view operation
296    ///
297    /// Determines whether the transpose operation can be performed as a
298    /// zero-copy view by manipulating strides rather than copying data.
299    /// This is possible for contiguous tensors when swapping different dimensions.
300    ///
301    /// # Arguments
302    ///
303    /// * `dim0` - First dimension to swap
304    /// * `dim1` - Second dimension to swap
305    ///
306    /// # Returns
307    ///
308    /// `true` if the transpose can be done as a view (zero-copy), `false` otherwise
309    ///
310    /// # Performance
311    ///
312    /// - **Time Complexity**: O(1) - Simple boolean checks
313    /// - **Memory Usage**: No allocation
314    ///
315    /// # Examples
316    ///
317    /// This method is used internally by the `transpose()` function to
318    /// determine the optimal implementation strategy (view vs copy).
319    fn can_transpose_as_view(&self, dim0: usize, dim1: usize) -> bool {
320        // For contiguous tensors, we can always create a view with different strides
321        // This is safe because we're not modifying the underlying data, just the access pattern
322        self.is_contiguous() && (dim0 != dim1)
323    }
324
325    /// Transpose with data copying when view operation is not possible
326    ///
327    /// Performs transpose by copying data to a new tensor when a view-based
328    /// transpose is not possible or optimal. This method ensures correct
329    /// transposition for all tensor types and memory layouts.
330    ///
331    /// # Arguments
332    ///
333    /// * `new_dims` - The new dimensions after transposition
334    /// * `_new_strides` - The new strides after transposition (unused in copy implementation)
335    /// * `dim0` - First dimension that was swapped
336    /// * `dim1` - Second dimension that was swapped
337    ///
338    /// # Returns
339    ///
340    /// A new tensor with copied and transposed data
341    ///
342    /// # Performance
343    ///
344    /// - **Time Complexity**: O(n) where n is the number of elements
345    /// - **Memory Usage**: Allocates new tensor with same total size
346    /// - **Data Integrity**: Ensures correct transposition for all cases
347    ///
348    /// # Examples
349    ///
350    /// This method is called internally by `transpose()` when view-based
351    /// transposition is not possible, such as for non-contiguous tensors
352    /// or complex memory layouts.
353    fn transpose_with_copy(
354        &self,
355        new_dims: Vec<usize>,
356        _new_strides: Vec<usize>,
357        dim0: usize,
358        dim1: usize,
359    ) -> Tensor {
360        let mut result = Tensor::new(new_dims.clone());
361
362        // Use stride-aware copying that correctly handles arbitrary dimension swaps
363        unsafe {
364            self.transpose_copy_stride_aware(&mut result, dim0, dim1);
365        }
366
367        // Preserve gradient tracking requirement
368        if self.requires_grad() {
369            result.set_requires_grad(true);
370        }
371
372        result
373    }
374
375    /// Stride-aware transpose copy that correctly handles arbitrary dimension swaps
376    ///
377    /// Performs efficient transpose copying using coordinate mapping and
378    /// stride calculations. This method correctly handles transposition
379    /// of any two dimensions in tensors of arbitrary rank and shape.
380    ///
381    /// # Arguments
382    ///
383    /// * `result` - Output tensor to write transposed data
384    /// * `dim0` - First dimension that was swapped
385    /// * `dim1` - Second dimension that was swapped
386    ///
387    /// # Safety
388    ///
389    /// This function uses unsafe pointer arithmetic for performance.
390    /// The caller must ensure:
391    /// * `result` tensor has the correct size and shape
392    /// * `result` tensor is properly allocated and accessible
393    /// * `dim0` and `dim1` are valid dimension indices
394    /// * Source tensor data is valid and accessible
395    ///
396    /// # Performance
397    ///
398    /// - **Time Complexity**: O(n) where n is the number of elements
399    /// - **Memory Access**: Optimized for cache-friendly access patterns
400    /// - **Coordinate Mapping**: Efficient conversion between linear and multi-dimensional indices
401    /// - **Bounds Checking**: Debug assertions for safety in debug builds
402    ///
403    /// # Examples
404    ///
405    /// This method is used internally by `transpose_with_copy()` to perform
406    /// the actual data copying with correct coordinate mapping for arbitrary
407    /// dimension swaps.
408    unsafe fn transpose_copy_stride_aware(&self, result: &mut Tensor, dim0: usize, dim1: usize) {
409        let src_ptr = self.as_ptr();
410        let dst_ptr = result.as_mut_ptr();
411
412        // Iterate through all elements of the result tensor
413        for dst_idx in 0..result.size() {
414            // Convert linear index to multi-dimensional coordinates for result
415            let mut dst_coords = Vec::new();
416            let mut temp_idx = dst_idx;
417
418            for &dim_size in result.shape().dims.iter().rev() {
419                dst_coords.push(temp_idx % dim_size);
420                temp_idx /= dim_size;
421            }
422            dst_coords.reverse();
423
424            // Map result coordinates to source coordinates (reverse the transpose)
425            let mut src_coords = dst_coords.clone();
426            src_coords.swap(dim0, dim1);
427
428            // Calculate source offset using strides
429            let src_offset = self.shape().offset(&src_coords);
430
431            // Bounds check to prevent buffer overruns
432            debug_assert!(
433                src_offset < self.size(),
434                "Source offset {} out of bounds for tensor size {}",
435                src_offset,
436                self.size()
437            );
438            debug_assert!(
439                dst_idx < result.size(),
440                "Destination index {} out of bounds for result size {}",
441                dst_idx,
442                result.size()
443            );
444
445            // Copy element
446            *dst_ptr.add(dst_idx) = *src_ptr.add(src_offset);
447        }
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transpose_2d_basic() {
        let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .expect("Failed to create tensor");
        let swapped = source.transpose(0, 1);

        assert_eq!(swapped.shape().dims, vec![3, 2]);

        // Original [2,3] layout [[1,2,3],[4,5,6]] becomes [[1,4],[2,5],[3,6]].
        let expectations = [
            ([0usize, 0usize], 1.0),
            ([0, 1], 4.0),
            ([1, 0], 2.0),
            ([1, 1], 5.0),
            ([2, 0], 3.0),
            ([2, 1], 6.0),
        ];
        for (coords, value) in expectations {
            assert_eq!(swapped.get(&coords), value);
        }
    }

    #[test]
    fn test_matrix_transpose() {
        let square =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let flipped = square.t();

        assert_eq!(flipped.shape().dims, vec![2, 2]);

        // [[1,2],[3,4]] transposes to [[1,3],[2,4]].
        let expectations = [
            ([0usize, 0usize], 1.0),
            ([0, 1], 3.0),
            ([1, 0], 2.0),
            ([1, 1], 4.0),
        ];
        for (coords, value) in expectations {
            assert_eq!(flipped.get(&coords), value);
        }
    }

    #[test]
    fn test_transpose_3d() {
        let cube = Tensor::new(vec![2, 3, 4]);
        let swapped = cube.transpose(0, 2);

        // Swapping the outermost dimensions turns [2,3,4] into [4,3,2].
        assert_eq!(swapped.shape().dims, vec![4, 3, 2]);
    }

    #[test]
    fn test_transpose_same_dimension() {
        let original =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let unchanged = original.transpose(1, 1);

        // Swapping a dimension with itself must be a no-op.
        assert_eq!(unchanged.shape().dims, original.shape().dims);
        for row in 0..2 {
            for col in 0..2 {
                assert_eq!(unchanged.get(&[row, col]), original.get(&[row, col]));
            }
        }
    }

    #[test]
    fn test_transpose_preserves_gradient_requirement() {
        let mut tracked = Tensor::new(vec![2, 3]);
        tracked.set_requires_grad(true);

        // Gradient tracking must propagate to the transposed result.
        assert!(tracked.transpose(0, 1).requires_grad());
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_transpose_invalid_dimension() {
        let small = Tensor::new(vec![2, 3]);
        small.transpose(0, 3); // dim 3 exceeds rank 2 -> must panic
    }

    #[test]
    #[should_panic(expected = "Matrix transpose requires at least 2 dimensions")]
    fn test_matrix_transpose_1d() {
        let vector = Tensor::new(vec![5]);
        vector.t(); // 1D input -> must panic
    }

    #[test]
    fn test_transpose_large_tensor() {
        // A 1024-element tensor exercises the larger copy/view paths.
        let big = Tensor::new(vec![32, 32]);
        assert_eq!(big.transpose(0, 1).shape().dims, vec![32, 32]);
    }

    #[test]
    fn test_transpose_memory_layout() {
        let base = Tensor::new(vec![3, 4]);
        assert!(base.is_contiguous());

        // The result is valid regardless of whether the implementation chose
        // a stride view or a data copy; only the shape is asserted here.
        let swapped = base.transpose(0, 1);
        assert_eq!(swapped.shape().dims, vec![4, 3]);
    }

    #[test]
    fn test_transpose_first_dimensions_3d() {
        // Regression test: transposing dimensions other than the last two
        // must remap data correctly.
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let cube = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        let swapped = cube.transpose(0, 1);

        // Shape changes from [2,3,4] to [3,2,4].
        assert_eq!(swapped.shape().dims, vec![3, 2, 4]);

        // swapped[d1][d0][d2] == original[d0][d1][d2]; spot-check d2 == 0.
        let expectations = [
            ([0usize, 0usize, 0usize], 0.0), // original [0,0,0]
            ([0, 1, 0], 12.0),               // original [1,0,0]
            ([1, 0, 0], 4.0),                // original [0,1,0]
            ([1, 1, 0], 16.0),               // original [1,1,0]
            ([2, 0, 0], 8.0),                // original [0,2,0]
            ([2, 1, 0], 20.0),               // original [1,2,0]
        ];
        for (coords, value) in expectations {
            assert_eq!(swapped.get(&coords), value);
        }
    }
}