// train_station/tensor/transform/transpose.rs

//! Tensor transpose operations
//!
//! This module provides tensor transpose functionality that swaps dimensions
//! of tensors, effectively changing the memory access pattern and logical
//! arrangement of data. Transposition is a fundamental tensor transformation
//! operation used in machine learning for matrix operations, preparing data
//! for specific layer types, and implementing complex tensor manipulations.
//!
//! # Operations
//!
//! * `transpose()` - Swap two specified dimensions of a tensor
//! * `t()` - Matrix transpose (swap last two dimensions)
//!
//! # Performance Characteristics
//!
//! * **Zero-Copy Operations**: Returns a view when possible using stride manipulation
//! * **Memory Efficient**: Reuses existing tensor data through view operations
//! * **Cache Optimized**: Uses optimized copying when view operations are not possible
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//! * **Shape Transformation**: Changes dimension order while preserving total elements
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Basic 2D transpose
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
//! let transposed = tensor.transpose(0, 1);
//! assert_eq!(transposed.shape().dims, vec![3, 2]);
//! ```
//!
//! ```
//! use train_station::Tensor;
//!
//! // Matrix transpose convenience method
//! let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let transposed = matrix.t();
//! assert_eq!(transposed.shape().dims, vec![2, 2]);
//! ```
//!
//! # Gradient Tracking
//!
//! The transpose operations support automatic gradient tracking through
//! the GradTrack system. When `requires_grad` is enabled, the operations
//! register gradient functions that apply the inverse transpose during
//! backward passes.

49use crate::tensor::core::Tensor;
50
51impl Tensor {
52    /// Transpose two dimensions of the tensor
53    ///
54    /// Swaps two specified dimensions of the tensor, modifying the shape
55    /// and memory access pattern. When possible, this operation returns
56    /// a zero-copy view using stride manipulation. For complex cases or
57    /// non-contiguous tensors, data is copied to ensure correct transposition.
58    ///
59    /// The transpose operation is its own inverse - applying transpose
60    /// twice with the same dimensions returns the original tensor.
61    ///
62    /// # Arguments
63    ///
64    /// * `dim0` - First dimension to swap (must be < tensor rank)
65    /// * `dim1` - Second dimension to swap (must be < tensor rank)
66    ///
67    /// # Returns
68    ///
69    /// A new tensor with the specified dimensions transposed. The total
70    /// number of elements remains unchanged.
71    ///
72    /// # Panics
73    ///
74    /// * If `dim0` is out of bounds for the tensor rank
75    /// * If `dim1` is out of bounds for the tensor rank
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use train_station::Tensor;
81    ///
82    /// // Basic 2D transpose
83    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
84    /// let transposed = tensor.transpose(0, 1);
85    /// assert_eq!(transposed.shape().dims, vec![3, 2]);
86    /// assert_eq!(transposed.get(&[0, 0]), 1.0);
87    /// assert_eq!(transposed.get(&[0, 1]), 4.0);
88    /// assert_eq!(transposed.get(&[1, 0]), 2.0);
89    /// assert_eq!(transposed.get(&[1, 1]), 5.0);
90    /// assert_eq!(transposed.get(&[2, 0]), 3.0);
91    /// assert_eq!(transposed.get(&[2, 1]), 6.0);
92    /// ```
93    ///
94    /// ```
95    /// use train_station::Tensor;
96    ///
97    /// // 3D tensor transpose
98    /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
99    /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
100    /// let transposed = tensor.transpose(0, 1);
101    /// assert_eq!(transposed.shape().dims, vec![3, 2, 4]);
102    /// ```
103    ///
104    /// ```
105    /// use train_station::Tensor;
106    ///
107    /// // Transpose with gradient tracking
108    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
109    /// tensor.set_requires_grad(true);
110    ///
111    /// let transposed = tensor.transpose(0, 1);
112    /// assert!(transposed.requires_grad());
113    /// assert_eq!(transposed.shape().dims, vec![2, 2]);
114    /// ```
115    ///
116    /// ```
117    /// use train_station::Tensor;
118    ///
119    /// // Transpose same dimension (no change)
120    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
121    /// let result = tensor.transpose(1, 1);
122    /// assert_eq!(result.shape().dims, tensor.shape().dims);
123    /// assert_eq!(result.get(&[0, 0]), tensor.get(&[0, 0]));
124    /// ```
125    ///
126    /// ```
127    /// use train_station::Tensor;
128    ///
129    /// // Transpose is its own inverse
130    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
131    /// let transposed = tensor.transpose(0, 1);
132    /// let double_transposed = transposed.transpose(0, 1);
133    /// assert_eq!(double_transposed.shape().dims, tensor.shape().dims);
134    /// assert_eq!(double_transposed.get(&[0, 0]), tensor.get(&[0, 0]));
135    /// ```
136    ///
137    /// # Performance
138    ///
139    /// - **Contiguous tensors**: O(1) time complexity, returns a view
140    /// - **Non-contiguous tensors**: O(n) time complexity with data copying
141    /// - **Memory usage**: No additional allocation for view operations
142    /// - **Gradient tracking**: Preserves gradient requirements and tracking
143    ///
144    /// # Relationship to Other Operations
145    ///
146    /// This operation is related to other tensor transformations:
147    /// - `t()` - Convenience method for matrix transpose (last two dimensions)
148    /// - `permute()` - More general dimension reordering operation
149    /// - `reshape()` - Changes shape without changing dimension order
150    ///
151    /// # Memory Layout
152    ///
153    /// For contiguous tensors, transpose returns a view with modified strides,
154    /// making the tensor non-contiguous. For non-contiguous tensors or complex
155    /// cases, data is copied to ensure correct transposition.
156    pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor {
157        assert!(
158            dim0 < self.shape().rank(),
159            "dim0 {} out of bounds for tensor with rank {}",
160            dim0,
161            self.shape().rank()
162        );
163        assert!(
164            dim1 < self.shape().rank(),
165            "dim1 {} out of bounds for tensor with rank {}",
166            dim1,
167            self.shape().rank()
168        );
169
170        // If same dimension, return a clone
171        if dim0 == dim1 {
172            return self.clone();
173        }
174
175        // Create new dimensions and strides by swapping
176        let mut new_dims = self.shape().dims.clone();
177        let mut new_strides = self.strides().to_vec();
178
179        new_dims.swap(dim0, dim1);
180        new_strides.swap(dim0, dim1);
181
182        // Create a view-based transpose when possible (creates non-contiguous tensor)
183        let mut result = if self.is_contiguous() && self.can_transpose_as_view(dim0, dim1) {
184            let new_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
185            self.create_view_with_shape(new_shape)
186        } else {
187            // Fallback to copy for complex cases
188            self.transpose_with_copy(new_dims, new_strides, dim0, dim1)
189        };
190
191        // GradTrack: register transpose for backward (transpose is its own inverse)
192        if self.requires_grad() {
193            result.set_requires_grad(true);
194            let grad_fn = crate::gradtrack::grad_fn::GradFn::Transpose {
195                dim0,
196                dim1,
197                input_shape: self.shape().dims.clone(),
198            };
199            result.set_grad_fn(grad_fn.clone());
200            crate::gradtrack::engine::GradEngine::register_operation(
201                result.id(),
202                vec![self.id()],
203                grad_fn,
204            );
205        }
206
207        result
208    }
209
210    /// Matrix transpose (transpose last two dimensions)
211    ///
212    /// Convenience method for the common case of matrix transposition.
213    /// For 2D tensors, this performs a standard matrix transpose.
214    /// For higher-dimensional tensors, this transposes the last two
215    /// dimensions, treating the tensor as a batch of matrices.
216    ///
217    /// This method is equivalent to `transpose(rank-2, rank-1)` where
218    /// `rank` is the number of dimensions in the tensor.
219    ///
220    /// # Returns
221    ///
222    /// A new tensor with the last two dimensions transposed
223    ///
224    /// # Panics
225    ///
226    /// * If the tensor has less than 2 dimensions
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// use train_station::Tensor;
232    ///
233    /// // 2D matrix transpose
234    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
235    /// let transposed = matrix.t();
236    /// assert_eq!(transposed.shape().dims, vec![2, 2]);
237    /// assert_eq!(transposed.get(&[0, 0]), 1.0);
238    /// assert_eq!(transposed.get(&[0, 1]), 3.0);
239    /// assert_eq!(transposed.get(&[1, 0]), 2.0);
240    /// assert_eq!(transposed.get(&[1, 1]), 4.0);
241    /// ```
242    ///
243    /// ```
244    /// use train_station::Tensor;
245    ///
246    /// // 3D tensor (batch of matrices)
247    /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
248    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
249    /// let transposed = tensor.t();
250    /// assert_eq!(transposed.shape().dims, vec![2, 3, 2]);
251    /// ```
252    ///
253    /// ```
254    /// use train_station::Tensor;
255    ///
256    /// // Matrix transpose with gradient tracking
257    /// let mut matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
258    /// matrix.set_requires_grad(true);
259    ///
260    /// let transposed = matrix.t();
261    /// assert!(transposed.requires_grad());
262    /// assert_eq!(transposed.shape().dims, vec![2, 2]);
263    /// ```
264    ///
265    /// # Performance
266    ///
267    /// - **Time Complexity**: Same as `transpose()` - O(1) for views, O(n) for copies
268    /// - **Memory Usage**: Same as `transpose()` - no allocation for views
269    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
270    ///
271    /// # Relationship to Other Operations
272    ///
273    /// This operation is equivalent to:
274    /// ```rust
275    /// use train_station::Tensor;
276    ///
277    /// let tensor = Tensor::new(vec![2, 3, 4]);
278    /// let rank = tensor.shape().rank();
279    /// let transposed1 = tensor.t();
280    /// let transposed2 = tensor.transpose(rank - 2, rank - 1);
281    /// // transposed1 and transposed2 are identical
282    /// ```
283    pub fn t(&self) -> Tensor {
284        assert!(
285            self.shape().rank() >= 2,
286            "Matrix transpose requires at least 2 dimensions, got {}",
287            self.shape().rank()
288        );
289        let rank = self.shape().rank();
290        self.transpose(rank - 2, rank - 1)
291    }
292
293    /// Check if transpose can be done as a zero-copy view operation
294    ///
295    /// Determines whether the transpose operation can be performed as a
296    /// zero-copy view by manipulating strides rather than copying data.
297    /// This is possible for contiguous tensors when swapping different dimensions.
298    ///
299    /// # Arguments
300    ///
301    /// * `dim0` - First dimension to swap
302    /// * `dim1` - Second dimension to swap
303    ///
304    /// # Returns
305    ///
306    /// `true` if the transpose can be done as a view (zero-copy), `false` otherwise
307    ///
308    /// # Performance
309    ///
310    /// - **Time Complexity**: O(1) - Simple boolean checks
311    /// - **Memory Usage**: No allocation
312    ///
313    /// # Examples
314    ///
315    /// This method is used internally by the `transpose()` function to
316    /// determine the optimal implementation strategy (view vs copy).
317    fn can_transpose_as_view(&self, dim0: usize, dim1: usize) -> bool {
318        // For contiguous tensors, we can always create a view with different strides
319        // This is safe because we're not modifying the underlying data, just the access pattern
320        self.is_contiguous() && (dim0 != dim1)
321    }
322
323    /// Transpose with data copying when view operation is not possible
324    ///
325    /// Performs transpose by copying data to a new tensor when a view-based
326    /// transpose is not possible or optimal. This method ensures correct
327    /// transposition for all tensor types and memory layouts.
328    ///
329    /// # Arguments
330    ///
331    /// * `new_dims` - The new dimensions after transposition
332    /// * `_new_strides` - The new strides after transposition (unused in copy implementation)
333    /// * `dim0` - First dimension that was swapped
334    /// * `dim1` - Second dimension that was swapped
335    ///
336    /// # Returns
337    ///
338    /// A new tensor with copied and transposed data
339    ///
340    /// # Performance
341    ///
342    /// - **Time Complexity**: O(n) where n is the number of elements
343    /// - **Memory Usage**: Allocates new tensor with same total size
344    /// - **Data Integrity**: Ensures correct transposition for all cases
345    ///
346    /// # Examples
347    ///
348    /// This method is called internally by `transpose()` when view-based
349    /// transposition is not possible, such as for non-contiguous tensors
350    /// or complex memory layouts.
351    fn transpose_with_copy(
352        &self,
353        new_dims: Vec<usize>,
354        _new_strides: Vec<usize>,
355        dim0: usize,
356        dim1: usize,
357    ) -> Tensor {
358        let mut result = Tensor::new(new_dims.clone());
359
360        // Use stride-aware copying that correctly handles arbitrary dimension swaps
361        unsafe {
362            self.transpose_copy_stride_aware(&mut result, dim0, dim1);
363        }
364
365        // Preserve gradient tracking requirement
366        if self.requires_grad() {
367            result.set_requires_grad(true);
368        }
369
370        result
371    }
372
373    /// Stride-aware transpose copy that correctly handles arbitrary dimension swaps
374    ///
375    /// Performs efficient transpose copying using coordinate mapping and
376    /// stride calculations. This method correctly handles transposition
377    /// of any two dimensions in tensors of arbitrary rank and shape.
378    ///
379    /// # Arguments
380    ///
381    /// * `result` - Output tensor to write transposed data
382    /// * `dim0` - First dimension that was swapped
383    /// * `dim1` - Second dimension that was swapped
384    ///
385    /// # Safety
386    ///
387    /// This function uses unsafe pointer arithmetic for performance.
388    /// The caller must ensure:
389    /// * `result` tensor has the correct size and shape
390    /// * `result` tensor is properly allocated and accessible
391    /// * `dim0` and `dim1` are valid dimension indices
392    /// * Source tensor data is valid and accessible
393    ///
394    /// # Performance
395    ///
396    /// - **Time Complexity**: O(n) where n is the number of elements
397    /// - **Memory Access**: Optimized for cache-friendly access patterns
398    /// - **Coordinate Mapping**: Efficient conversion between linear and multi-dimensional indices
399    /// - **Bounds Checking**: Debug assertions for safety in debug builds
400    ///
401    /// # Examples
402    ///
403    /// This method is used internally by `transpose_with_copy()` to perform
404    /// the actual data copying with correct coordinate mapping for arbitrary
405    /// dimension swaps.
406    unsafe fn transpose_copy_stride_aware(&self, result: &mut Tensor, dim0: usize, dim1: usize) {
407        let src_ptr = self.as_ptr();
408        let dst_ptr = result.as_mut_ptr();
409
410        // Iterate through all elements of the result tensor
411        for dst_idx in 0..result.size() {
412            // Convert linear index to multi-dimensional coordinates for result
413            let mut dst_coords = Vec::new();
414            let mut temp_idx = dst_idx;
415
416            for &dim_size in result.shape().dims.iter().rev() {
417                dst_coords.push(temp_idx % dim_size);
418                temp_idx /= dim_size;
419            }
420            dst_coords.reverse();
421
422            // Map result coordinates to source coordinates (reverse the transpose)
423            let mut src_coords = dst_coords.clone();
424            src_coords.swap(dim0, dim1);
425
426            // Calculate source offset using strides
427            let src_offset = self.shape().offset(&src_coords);
428
429            // Bounds check to prevent buffer overruns
430            debug_assert!(
431                src_offset < self.size(),
432                "Source offset {} out of bounds for tensor size {}",
433                src_offset,
434                self.size()
435            );
436            debug_assert!(
437                dst_idx < result.size(),
438                "Destination index {} out of bounds for result size {}",
439                dst_idx,
440                result.size()
441            );
442
443            // Copy element
444            *dst_ptr.add(dst_idx) = *src_ptr.add(src_offset);
445        }
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    // Shape change and full element mapping for a 2D transpose.
    #[test]
    fn test_transpose_2d_basic() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .expect("Failed to create tensor");
        let transposed = tensor.transpose(0, 1);

        assert_eq!(transposed.shape().dims, vec![3, 2]);

        // Verify data layout: original [2,3] -> transposed [3,2]
        // Original: [[1,2,3], [4,5,6]]
        // Transposed: [[1,4], [2,5], [3,6]]
        assert_eq!(transposed.get(&[0, 0]), 1.0);
        assert_eq!(transposed.get(&[0, 1]), 4.0);
        assert_eq!(transposed.get(&[1, 0]), 2.0);
        assert_eq!(transposed.get(&[1, 1]), 5.0);
        assert_eq!(transposed.get(&[2, 0]), 3.0);
        assert_eq!(transposed.get(&[2, 1]), 6.0);
    }

    // The t() convenience method performs a standard matrix transpose.
    #[test]
    fn test_matrix_transpose() {
        let matrix =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let transposed = matrix.t();

        assert_eq!(transposed.shape().dims, vec![2, 2]);

        // Original: [[1,2], [3,4]]
        // Transposed: [[1,3], [2,4]]
        assert_eq!(transposed.get(&[0, 0]), 1.0);
        assert_eq!(transposed.get(&[0, 1]), 3.0);
        assert_eq!(transposed.get(&[1, 0]), 2.0);
        assert_eq!(transposed.get(&[1, 1]), 4.0);
    }

    // Swapping non-adjacent dimensions of a 3D tensor reverses the shape ends.
    #[test]
    fn test_transpose_3d() {
        let tensor = Tensor::new(vec![2, 3, 4]);
        let transposed = tensor.transpose(0, 2);

        // Shape changes from [2,3,4] to [4,3,2]
        assert_eq!(transposed.shape().dims, vec![4, 3, 2]);
    }

    // transpose(d, d) is the identity: shape and data are unchanged.
    #[test]
    fn test_transpose_same_dimension() {
        let tensor =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let result = tensor.transpose(1, 1);

        // Should be identical to original
        assert_eq!(result.shape().dims, tensor.shape().dims);
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(result.get(&[i, j]), tensor.get(&[i, j]));
            }
        }
    }

    // requires_grad must propagate from input to transposed output.
    #[test]
    fn test_transpose_preserves_gradient_requirement() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.set_requires_grad(true);
        let transposed = tensor.transpose(0, 1);

        assert!(transposed.requires_grad());
    }

    // Out-of-range dimension indices must panic with a bounds message.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_transpose_invalid_dimension() {
        let tensor = Tensor::new(vec![2, 3]);
        tensor.transpose(0, 3); // Should panic: dim 3 out of bounds
    }

    // t() requires at least a matrix; 1D input must panic.
    #[test]
    #[should_panic(expected = "Matrix transpose requires at least 2 dimensions")]
    fn test_matrix_transpose_1d() {
        let tensor = Tensor::new(vec![5]);
        tensor.t(); // Should panic: 1D tensor
    }

    #[test]
    fn test_transpose_large_tensor() {
        // Test with larger tensor to exercise cache-optimized path
        let tensor = Tensor::new(vec![32, 32]); // 1024 elements
        let transposed = tensor.transpose(0, 1);

        assert_eq!(transposed.shape().dims, vec![32, 32]);
    }

    // A transpose of a contiguous tensor yields a valid (possibly
    // non-contiguous) result with the swapped shape.
    #[test]
    fn test_transpose_memory_layout() {
        let tensor = Tensor::new(vec![3, 4]);
        assert!(tensor.is_contiguous());

        let transposed = tensor.transpose(0, 1);
        // After transpose, the result should still be valid but may not be contiguous
        // depending on implementation (view vs copy)
        assert_eq!(transposed.shape().dims, vec![4, 3]);
    }

    #[test]
    fn test_transpose_first_dimensions_3d() {
        // Test the critical bug fix: transpose dimensions other than last two
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Transpose the first two dimensions (not last two)
        let transposed = tensor.transpose(0, 1);

        // Shape should change from [2,3,4] to [3,2,4]
        assert_eq!(transposed.shape().dims, vec![3, 2, 4]);

        // Verify data is correctly transposed
        // Original: tensor[d0][d1][d2] where d0=2, d1=3, d2=4
        // After transpose(0,1): tensor[d1][d0][d2] where d1=3, d0=2, d2=4

        assert_eq!(transposed.get(&[0, 0, 0]), 0.0); // Maps to original [0,0,0]
        assert_eq!(transposed.get(&[0, 1, 0]), 12.0); // Maps to original [1,0,0]
        assert_eq!(transposed.get(&[1, 0, 0]), 4.0); // Maps to original [0,1,0]
        assert_eq!(transposed.get(&[1, 1, 0]), 16.0); // Maps to original [1,1,0]
        assert_eq!(transposed.get(&[2, 0, 0]), 8.0); // Maps to original [0,2,0]
        assert_eq!(transposed.get(&[2, 1, 0]), 20.0); // Maps to original [1,2,0]
    }
}