1use slop_alloc::{Backend, CpuBackend};
2
3use crate::{Tensor, TensorViewMut};
4
5pub trait TransposeBackend<T>: Backend {
10 fn transpose_tensor_into(src: &Tensor<T, Self>, dst: TensorViewMut<T, Self>);
12}
13
14impl<T, A: TransposeBackend<T>> Tensor<T, A> {
15 pub fn transpose(&self) -> Tensor<T, A> {
19 let mut sizes = self.sizes().to_vec();
20 let len = sizes.len();
21 assert_eq!(len, 2, "Transpose is only supported for 2D tensors");
22 sizes.swap(len - 1, len - 2);
23 let mut dst = Tensor::with_sizes_in(sizes, self.backend().clone());
24
25 unsafe {
26 dst.assume_init();
27 }
28 A::transpose_tensor_into(self, dst.as_view_mut());
29
30 dst
31 }
32}
33
34impl<T: Copy> TransposeBackend<T> for CpuBackend {
35 fn transpose_tensor_into(src: &Tensor<T, Self>, dst: TensorViewMut<T, Self>) {
36 debug_assert_eq!(src.sizes().len(), 2);
38 debug_assert_eq!(dst.sizes().len(), 2);
39 debug_assert_eq!(src.sizes()[src.sizes().len() - 1], dst.sizes()[dst.sizes().len() - 2]);
40 debug_assert_eq!(src.sizes()[src.sizes().len() - 2], dst.sizes()[dst.sizes().len() - 1]);
41
42 let input_width = src.sizes()[src.sizes().len() - 1];
44 let input_height = src.sizes()[src.sizes().len() - 2];
45
46 transpose::transpose(src.as_buffer(), dst.as_mut_slice(), input_width, input_height);
47 }
48}
49
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// Shapes with at most this many elements are compared element-by-element.
    const EXHAUSTIVE_LIMIT: usize = 1 << 20;
    /// Number of random positions compared for every shape, including large ones.
    const NUM_SAMPLES: usize = 1000;

    /// Checks `transpose` on a range of `(width, height)` shapes.
    ///
    /// Comparing a single random position per shape would miss a transpose that is
    /// wrong almost everywhere. So small shapes are checked exhaustively, and every
    /// shape is additionally checked at many random positions. The degenerate
    /// single-element, single-row and single-column shapes are included.
    #[test]
    fn test_transpose() {
        let mut rng = rand::thread_rng();

        for (width, height) in
            [(1, 1), (1, 7), (7, 1), (2, 3), (5, 10), (100, 500), (1000, 1 << 16)]
        {
            let tensor =
                Tensor::<u32>::from((0..width * height).map(|_| rng.gen()).collect::<Vec<_>>())
                    .reshape([height, width]);

            let transposed = tensor.transpose();
            assert_eq!(transposed.sizes(), &[width, height]);

            if width * height <= EXHAUSTIVE_LIMIT {
                for i in 0..height {
                    for j in 0..width {
                        assert_eq!(
                            tensor[[i, j]],
                            transposed[[j, i]],
                            "mismatch at ({i}, {j}) for a {height}x{width} tensor"
                        );
                    }
                }
            }

            for _ in 0..NUM_SAMPLES {
                let i = rng.gen_range(0..height);
                let j = rng.gen_range(0..width);
                assert_eq!(
                    tensor[[i, j]],
                    transposed[[j, i]],
                    "mismatch at ({i}, {j}) for a {height}x{width} tensor"
                );
            }
        }
    }

    /// `transpose` asserts that its input is 2D; a 3D tensor must be rejected.
    #[test]
    #[should_panic(expected = "Transpose is only supported for 2D tensors")]
    fn test_transpose_rejects_non_2d() {
        let tensor = Tensor::<u32>::from((0..24).collect::<Vec<u32>>()).reshape([2, 3, 4]);
        let _ = tensor.transpose();
    }
}