train_station/tensor/transform/transpose.rs

//! Tensor transpose operations
//!
//! This module provides tensor transpose functionality that swaps dimensions
//! of tensors, effectively changing the memory access pattern and logical
//! arrangement of data. Transposition is a fundamental tensor transformation
//! operation used in machine learning for matrix operations, preparing data
//! for specific layer types, and implementing complex tensor manipulations.
//!
//! # Operations
//!
//! * `transpose()` - Swap two specified dimensions of a tensor
//! * `t()` - Matrix transpose (swap last two dimensions)
//!
//! # Performance Characteristics
//!
//! * **Zero-Copy Operations**: Returns a view when possible using stride manipulation
//! * **Memory Efficient**: Reuses existing tensor data through view operations
//! * **Cache Optimized**: Uses optimized copying when view operations are not possible
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//! * **Shape Transformation**: Changes dimension order while preserving total elements
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Basic 2D transpose
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
//! let transposed = tensor.transpose(0, 1);
//! assert_eq!(transposed.shape().dims(), vec![3, 2]);
//! ```
//!
//! ```
//! use train_station::Tensor;
//!
//! // Matrix transpose convenience method
//! let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let transposed = matrix.t();
//! assert_eq!(transposed.shape().dims(), vec![2, 2]);
//! ```
//!
//! # Gradient Tracking
//!
//! The transpose operations support automatic gradient tracking through
//! the GradTrack system. When `requires_grad` is enabled, the operations
//! register gradient functions that apply the inverse transpose during
//! backward passes.
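//!
//! Because transposing twice with the same dimension pair restores the original
//! layout, the backward pass can route gradients back by simply re-applying the
//! same transpose. A small illustration of that inverse property:
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
//! let round_trip = tensor.transpose(0, 1).transpose(0, 1);
//! assert_eq!(round_trip.shape().dims(), tensor.shape().dims());
//! assert_eq!(round_trip.get(&[1, 2]), tensor.get(&[1, 2]));
//! ```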

use crate::tensor::core::Tensor;

impl Tensor {
    /// Transpose two dimensions of the tensor
    ///
    /// Swaps two specified dimensions of the tensor, modifying the shape
    /// and memory access pattern. When possible, this operation returns
    /// a zero-copy view using stride manipulation. For complex cases or
    /// non-contiguous tensors, data is copied to ensure correct transposition.
    ///
    /// The transpose operation is its own inverse - applying transpose
    /// twice with the same dimensions returns the original tensor.
    ///
    /// # Arguments
    ///
    /// * `dim0` - First dimension to swap (must be < tensor rank)
    /// * `dim1` - Second dimension to swap (must be < tensor rank)
    ///
    /// # Returns
    ///
    /// A new tensor with the specified dimensions transposed. The total
    /// number of elements remains unchanged.
    ///
    /// # Panics
    ///
    /// * If `dim0` is out of bounds for the tensor rank
    /// * If `dim1` is out of bounds for the tensor rank
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Basic 2D transpose
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    /// let transposed = tensor.transpose(0, 1);
    /// assert_eq!(transposed.shape().dims(), vec![3, 2]);
    /// assert_eq!(transposed.get(&[0, 0]), 1.0);
    /// assert_eq!(transposed.get(&[0, 1]), 4.0);
    /// assert_eq!(transposed.get(&[1, 0]), 2.0);
    /// assert_eq!(transposed.get(&[1, 1]), 5.0);
    /// assert_eq!(transposed.get(&[2, 0]), 3.0);
    /// assert_eq!(transposed.get(&[2, 1]), 6.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // 3D tensor transpose
    /// let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
    /// let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
    /// let transposed = tensor.transpose(0, 1);
    /// assert_eq!(transposed.shape().dims(), vec![3, 2, 4]);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Transpose with gradient tracking
    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// tensor.set_requires_grad(true);
    ///
    /// let transposed = tensor.transpose(0, 1);
    /// assert!(transposed.requires_grad());
    /// assert_eq!(transposed.shape().dims(), vec![2, 2]);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Transpose same dimension (no change)
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let result = tensor.transpose(1, 1);
    /// assert_eq!(result.shape().dims(), tensor.shape().dims());
    /// assert_eq!(result.get(&[0, 0]), tensor.get(&[0, 0]));
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Transpose is its own inverse
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let transposed = tensor.transpose(0, 1);
    /// let double_transposed = transposed.transpose(0, 1);
    /// assert_eq!(double_transposed.shape().dims(), tensor.shape().dims());
    /// assert_eq!(double_transposed.get(&[0, 0]), tensor.get(&[0, 0]));
    /// ```
    ///
    /// # Performance
    ///
    /// - **Contiguous tensors**: O(1) time complexity, returns a view
    /// - **Non-contiguous tensors**: O(n) time complexity with data copying
    /// - **Memory usage**: No additional allocation for view operations
    /// - **Gradient tracking**: Preserves gradient requirements and tracking
    ///
    /// # Relationship to Other Operations
    ///
    /// This operation is related to other tensor transformations:
    /// - `t()` - Convenience method for matrix transpose (last two dimensions)
    /// - `permute()` - More general dimension reordering operation
    /// - `reshape()` - Changes shape without changing dimension order
    ///
    /// # Memory Layout
    ///
    /// For contiguous tensors, transpose returns a view with modified strides,
    /// making the tensor non-contiguous. For non-contiguous tensors or complex
    /// cases, data is copied to ensure correct transposition.
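    ///
    /// A quick way to see this in practice (the contiguity of the result is only
    /// queried here, not asserted, since whether the view or the copying path is
    /// taken is an implementation detail):
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![3, 4]);
    /// assert!(tensor.is_contiguous());
    ///
    /// let transposed = tensor.transpose(0, 1);
    /// // A stride-swapped view typically reports non-contiguous; a copying
    /// // fallback would report contiguous. Shape and values match either way.
    /// let _result_contiguous = transposed.is_contiguous();
    /// assert_eq!(transposed.shape().dims(), vec![4, 3]);
    /// ```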
    #[track_caller]
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor {
        assert!(
            dim0 < self.shape().rank(),
            "dim0 {} out of bounds for tensor with rank {}",
            dim0,
            self.shape().rank()
        );
        assert!(
            dim1 < self.shape().rank(),
            "dim1 {} out of bounds for tensor with rank {}",
            dim1,
            self.shape().rank()
        );

        // If same dimension, return a clone
        if dim0 == dim1 {
            return self.clone();
        }

        // Build permutation and delegate to core view
        let rank = self.shape().rank();
        let mut perm: Vec<usize> = (0..rank).collect();
        perm.swap(dim0, dim1);
        let mut result = match crate::tensor::core::view::transpose_view(self, &perm) {
            Ok(v) => v,
            Err(e) => panic!("transpose view error: {:?}", e),
        };

        // GradTrack: register transpose for backward (transpose is its own inverse)
        if self.requires_grad() {
            result.set_requires_grad(true);
            let grad_fn = crate::gradtrack::grad_fn::GradFn::Transpose {
                dim0,
                dim1,
                input_shape: self.shape().dims().to_vec(),
            };
            result.set_grad_fn(grad_fn.clone());
            crate::gradtrack::engine::GradEngine::register_operation(
                result.id(),
                vec![self.id()],
                grad_fn,
            );
        }

        result
    }

    /// Matrix transpose (transpose last two dimensions)
    ///
    /// Convenience method for the common case of matrix transposition.
    /// For 2D tensors, this performs a standard matrix transpose.
    /// For higher-dimensional tensors, this transposes the last two
    /// dimensions, treating the tensor as a batch of matrices.
    ///
    /// This method is equivalent to `transpose(rank-2, rank-1)` where
    /// `rank` is the number of dimensions in the tensor.
    ///
    /// # Returns
    ///
    /// A new tensor with the last two dimensions transposed
    ///
    /// # Panics
    ///
    /// * If the tensor has fewer than 2 dimensions
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // 2D matrix transpose
    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let transposed = matrix.t();
    /// assert_eq!(transposed.shape().dims(), vec![2, 2]);
    /// assert_eq!(transposed.get(&[0, 0]), 1.0);
    /// assert_eq!(transposed.get(&[0, 1]), 3.0);
    /// assert_eq!(transposed.get(&[1, 0]), 2.0);
    /// assert_eq!(transposed.get(&[1, 1]), 4.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // 3D tensor (batch of matrices)
    /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
    /// let transposed = tensor.t();
    /// assert_eq!(transposed.shape().dims(), vec![2, 3, 2]);
    /// ```
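    ///
    /// Element-wise, the last two indices swap while any leading batch indices are
    /// left untouched, which follows from the equivalence with
    /// `transpose(rank - 2, rank - 1)`:
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
    /// let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
    /// let transposed = tensor.t();
    /// // transposed[b][j][i] == tensor[b][i][j]
    /// assert_eq!(transposed.get(&[0, 0, 1]), tensor.get(&[0, 1, 0]));
    /// assert_eq!(transposed.get(&[1, 2, 0]), tensor.get(&[1, 0, 2]));
    /// ```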
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Matrix transpose with gradient tracking
    /// let mut matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// matrix.set_requires_grad(true);
    ///
    /// let transposed = matrix.t();
    /// assert!(transposed.requires_grad());
    /// assert_eq!(transposed.shape().dims(), vec![2, 2]);
    /// ```
    ///
    /// # Performance
    ///
    /// - **Time Complexity**: Same as `transpose()` - O(1) for views, O(n) for copies
    /// - **Memory Usage**: Same as `transpose()` - no allocation for views
    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
    ///
    /// # Relationship to Other Operations
    ///
    /// This operation is equivalent to:
    /// ```rust
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::new(vec![2, 3, 4]);
    /// let rank = tensor.shape().rank();
    /// let transposed1 = tensor.t();
    /// let transposed2 = tensor.transpose(rank - 2, rank - 1);
    /// // transposed1 and transposed2 are identical
    /// ```
    #[track_caller]
    pub fn t(&self) -> Tensor {
        assert!(
            self.shape().rank() >= 2,
            "Matrix transpose requires at least 2 dimensions, got {}",
            self.shape().rank()
        );
        let rank = self.shape().rank();
        self.transpose(rank - 2, rank - 1)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transpose_2d_basic() {
        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
            .expect("Failed to create tensor");
        let transposed = tensor.transpose(0, 1);

        assert_eq!(transposed.shape().dims(), vec![3, 2]);

        // Verify data layout: original [2,3] -> transposed [3,2]
        // Original: [[1,2,3], [4,5,6]]
        // Transposed: [[1,4], [2,5], [3,6]]
        assert_eq!(transposed.get(&[0, 0]), 1.0);
        assert_eq!(transposed.get(&[0, 1]), 4.0);
        assert_eq!(transposed.get(&[1, 0]), 2.0);
        assert_eq!(transposed.get(&[1, 1]), 5.0);
        assert_eq!(transposed.get(&[2, 0]), 3.0);
        assert_eq!(transposed.get(&[2, 1]), 6.0);
    }

    #[test]
    fn test_matrix_transpose() {
        let matrix =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let transposed = matrix.t();

        assert_eq!(transposed.shape().dims(), vec![2, 2]);

        // Original: [[1,2], [3,4]]
        // Transposed: [[1,3], [2,4]]
        assert_eq!(transposed.get(&[0, 0]), 1.0);
        assert_eq!(transposed.get(&[0, 1]), 3.0);
        assert_eq!(transposed.get(&[1, 0]), 2.0);
        assert_eq!(transposed.get(&[1, 1]), 4.0);
    }

    #[test]
    fn test_transpose_3d() {
        let tensor = Tensor::new(vec![2, 3, 4]);
        let transposed = tensor.transpose(0, 2);

        // Shape changes from [2,3,4] to [4,3,2]
        assert_eq!(transposed.shape().dims(), vec![4, 3, 2]);
    }

    #[test]
    fn test_transpose_same_dimension() {
        let tensor =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).expect("Failed to create tensor");
        let result = tensor.transpose(1, 1);

        // Should be identical to original
        assert_eq!(result.shape().dims(), tensor.shape().dims());
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(result.get(&[i, j]), tensor.get(&[i, j]));
            }
        }
    }

    #[test]
    fn test_transpose_preserves_gradient_requirement() {
        let mut tensor = Tensor::new(vec![2, 3]);
        tensor.set_requires_grad(true);
        let transposed = tensor.transpose(0, 1);

        assert!(transposed.requires_grad());
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_transpose_invalid_dimension() {
        let tensor = Tensor::new(vec![2, 3]);
        tensor.transpose(0, 3); // Should panic: dim 3 out of bounds
    }

    #[test]
    #[should_panic(expected = "Matrix transpose requires at least 2 dimensions")]
    fn test_matrix_transpose_1d() {
        let tensor = Tensor::new(vec![5]);
        tensor.t(); // Should panic: 1D tensor
    }

    #[test]
    fn test_transpose_large_tensor() {
        // Test with larger tensor to exercise cache-optimized path
        let tensor = Tensor::new(vec![32, 32]); // 1024 elements
        let transposed = tensor.transpose(0, 1);

        assert_eq!(transposed.shape().dims(), vec![32, 32]);
    }

    #[test]
    fn test_transpose_memory_layout() {
        let tensor = Tensor::new(vec![3, 4]);
        assert!(tensor.is_contiguous());

        let transposed = tensor.transpose(0, 1);
        // After transpose, the result should still be valid but may not be contiguous
        // depending on implementation (view vs copy)
        assert_eq!(transposed.shape().dims(), vec![4, 3]);
    }

    #[test]
    fn test_transpose_first_dimensions_3d() {
        // Test the critical bug fix: transpose dimensions other than last two
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Transpose the first two dimensions (not last two)
        let transposed = tensor.transpose(0, 1);

        // Shape should change from [2,3,4] to [3,2,4]
        assert_eq!(transposed.shape().dims(), vec![3, 2, 4]);

        // Verify data is correctly transposed
        // Original: tensor[d0][d1][d2] where d0=2, d1=3, d2=4
        // After transpose(0,1): tensor[d1][d0][d2] where d1=3, d0=2, d2=4

        assert_eq!(transposed.get(&[0, 0, 0]), 0.0); // Maps to original [0,0,0]
        assert_eq!(transposed.get(&[0, 1, 0]), 12.0); // Maps to original [1,0,0]
        assert_eq!(transposed.get(&[1, 0, 0]), 4.0); // Maps to original [0,1,0]
        assert_eq!(transposed.get(&[1, 1, 0]), 16.0); // Maps to original [1,1,0]
        assert_eq!(transposed.get(&[2, 0, 0]), 8.0); // Maps to original [0,2,0]
        assert_eq!(transposed.get(&[2, 1, 0]), 20.0); // Maps to original [1,2,0]
    }
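
    // Round-trip check: transposing twice with the same dimension pair should
    // restore both the shape and every element value (the inverse property noted
    // in the `transpose` documentation). A minimal sketch using only the APIs
    // already exercised above.
    #[test]
    fn test_transpose_round_trip_3d() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).expect("Failed to create tensor");

        let round_trip = tensor.transpose(0, 2).transpose(0, 2);

        assert_eq!(round_trip.shape().dims(), tensor.shape().dims());
        for d0 in 0..2 {
            for d1 in 0..3 {
                for d2 in 0..4 {
                    assert_eq!(round_trip.get(&[d0, d1, d2]), tensor.get(&[d0, d1, d2]));
                }
            }
        }
    }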
}