Skip to main content

slop_tensor/
transpose.rs

1use slop_alloc::{Backend, CpuBackend};
2
3use crate::{Tensor, TensorViewMut};
4
5/// A backend that supports the 2D transpose operation.
6///
7/// The operation assumes the input tensor is a 2D tensor with the last two dimensions being the
8/// dimensions to be transposed.
9pub trait TransposeBackend<T>: Backend {
10    /// Transposes the input tensor into the output tensor.
11    fn transpose_tensor_into(src: &Tensor<T, Self>, dst: TensorViewMut<T, Self>);
12}
13
14impl<T, A: TransposeBackend<T>> Tensor<T, A> {
15    /// Returns a new tensor with the last two dimensions transposed.
16    ///
17    /// This function panics if the input tensor is not a 2D tensor.
18    pub fn transpose(&self) -> Tensor<T, A> {
19        let mut sizes = self.sizes().to_vec();
20        let len = sizes.len();
21        assert_eq!(len, 2, "Transpose is only supported for 2D tensors");
22        sizes.swap(len - 1, len - 2);
23        let mut dst = Tensor::with_sizes_in(sizes, self.backend().clone());
24
25        unsafe {
26            dst.assume_init();
27        }
28        A::transpose_tensor_into(self, dst.as_view_mut());
29
30        dst
31    }
32}
33
34impl<T: Copy> TransposeBackend<T> for CpuBackend {
35    fn transpose_tensor_into(src: &Tensor<T, Self>, dst: TensorViewMut<T, Self>) {
36        // Dimension checks.
37        debug_assert_eq!(src.sizes().len(), 2);
38        debug_assert_eq!(dst.sizes().len(), 2);
39        debug_assert_eq!(src.sizes()[src.sizes().len() - 1], dst.sizes()[dst.sizes().len() - 2]);
40        debug_assert_eq!(src.sizes()[src.sizes().len() - 2], dst.sizes()[dst.sizes().len() - 1]);
41
42        // Transpose the data.
43        let input_width = src.sizes()[src.sizes().len() - 1];
44        let input_height = src.sizes()[src.sizes().len() - 2];
45
46        transpose::transpose(src.as_buffer(), dst.as_mut_slice(), input_width, input_height);
47    }
48}
49
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// Transposes random `u32` tensors of several shapes, including single-row and
    /// single-column cases, and checks that `tensor[[i, j]] == transposed[[j, i]]`.
    ///
    /// Tensors with at most `EXHAUSTIVE_LIMIT` elements are compared element by element.
    /// Larger ones are checked at their four corners plus `RANDOM_SAMPLES` random positions,
    /// because an exhaustive indexed comparison would dominate the test's runtime.
    #[test]
    fn test_transpose() {
        const EXHAUSTIVE_LIMIT: usize = 100_000;
        const RANDOM_SAMPLES: usize = 1_000;

        let mut rng = rand::thread_rng();

        for (width, height) in
            [(1, 1), (1, 7), (7, 1), (2, 3), (5, 10), (100, 500), (1000, 1 << 16)]
        {
            let tensor =
                Tensor::<u32>::from((0..width * height).map(|_| rng.gen()).collect::<Vec<_>>())
                    .reshape([height, width]);

            let transposed = tensor.transpose();
            assert_eq!(transposed.sizes(), &[width, height]);

            if width * height <= EXHAUSTIVE_LIMIT {
                for i in 0..height {
                    for j in 0..width {
                        assert_eq!(tensor[[i, j]], transposed[[j, i]], "mismatch at ({}, {})", i, j);
                    }
                }
            } else {
                let corners = [(0, 0), (0, width - 1), (height - 1, 0), (height - 1, width - 1)];
                let samples: Vec<(usize, usize)> = (0..RANDOM_SAMPLES)
                    .map(|_| (rng.gen_range(0..height), rng.gen_range(0..width)))
                    .collect();
                for &(i, j) in corners.iter().chain(samples.iter()) {
                    assert_eq!(tensor[[i, j]], transposed[[j, i]], "mismatch at ({}, {})", i, j);
                }
            }
        }
    }
}