zenu_matrix/operation/
transpose.rs

1use crate::{
2    device::Device,
3    dim::{DimDyn, DimTrait},
4    matrix::{Matrix, Owned, Repr},
5};
6
7impl<R: Repr, D: Device> Matrix<R, DimDyn, D> {
8    pub fn transpose(&mut self) {
9        let shape_stride = self.shape_stride();
10        let transposed = shape_stride.transpose();
11        self.update_shape_stride(transposed);
12    }
13
14    pub fn transpose_by_index(&mut self, index: &[usize]) {
15        let shape_stride = self.shape_stride();
16        let transposed = shape_stride.transpose_by_index(index);
17        self.update_shape_stride(transposed);
18    }
19
20    #[must_use]
21    pub fn transpose_by_index_new_matrix(
22        &self,
23        index: &[usize],
24    ) -> Matrix<Owned<R::Item>, DimDyn, D> {
25        let mut ref_mat = self.to_ref();
26        ref_mat.transpose_by_index(index);
27        ref_mat.to_default_stride()
28    }
29
30    #[expect(clippy::missing_panics_doc)]
31    pub fn transpose_swap_index(&mut self, a: usize, b: usize) {
32        assert!(a != b, "Index must be different");
33        if a < b {
34            return self.transpose_swap_index(b, a);
35        }
36        assert!(a < self.shape().len(), "Index out of range");
37        assert!(b < self.shape().len(), "Index out of range");
38
39        let shape_stride = self.shape_stride().swap_index(a, b);
40        self.update_shape_stride(shape_stride);
41    }
42
43    #[expect(clippy::missing_panics_doc)]
44    #[must_use]
45    pub fn transpose_swap_index_new_matrix(
46        &self,
47        a: usize,
48        b: usize,
49    ) -> Matrix<Owned<R::Item>, DimDyn, D> {
50        assert!(a != b, "Index must be different");
51        if a < b {
52            return self.transpose_swap_index_new_matrix(b, a);
53        }
54        let mut ref_mat = self.to_ref();
55        ref_mat.transpose_swap_index(a, b);
56        ref_mat.to_default_stride()
57    }
58}
59
#[expect(clippy::float_cmp)]
#[cfg(test)]
mod transpose {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Device-generic body of the 2-D transpose test.
    ///
    /// This is not a `#[test]` itself; the per-device tests below instantiate it.
    fn transpose_2d<D: Device>() {
        let mut mat: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        mat.transpose();

        // (index into the transposed matrix, element expected at that index)
        let expected: [([usize; 2], f32); 6] = [
            ([0, 0], 1.),
            ([0, 1], 4.),
            ([1, 0], 2.),
            ([1, 1], 5.),
            ([2, 0], 3.),
            ([2, 1], 6.),
        ];
        for (index, value) in expected {
            assert_eq!(mat.index_item(index), value);
        }
    }

    #[test]
    fn transpose_2d_cpu() {
        transpose_2d::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn transpose_2d_cuda() {
        transpose_2d::<crate::device::nvidia::Nvidia>();
    }
}
91
#[expect(clippy::cast_precision_loss)]
#[cfg(test)]
mod transpose_inplace {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Device-generic body of the 4-D axis permutation test.
    ///
    /// Every element encodes its own coordinate `(i, j, k, l)` as the decimal number
    /// `ijkl`, so the expected result of exchanging the first two axes can be produced by
    /// emitting the same values with `j` as the outermost loop.
    fn inplace_transpose_4d<D: Device>() {
        let encode = |i: i32, j: i32, k: i32, l: i32| (i * 1000 + j * 100 + k * 10 + l) as f32;

        // Source values for shape [3, 4, 5, 6]: `i` outermost, `l` innermost.
        let mut source = Vec::with_capacity(3 * 4 * 5 * 6);
        for i in 0..3 {
            for j in 0..4 {
                for k in 0..5 {
                    for l in 0..6 {
                        source.push(encode(i, j, k, l));
                    }
                }
            }
        }

        // Expected values for shape [4, 3, 5, 6]: same elements, `j` outermost.
        let mut expected = Vec::with_capacity(source.len());
        for j in 0..4 {
            for i in 0..3 {
                for k in 0..5 {
                    for l in 0..6 {
                        expected.push(encode(i, j, k, l));
                    }
                }
            }
        }

        let input: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(source, [3, 4, 5, 6]);
        let output = input.transpose_by_index_new_matrix(&[1, 0, 2, 3]);
        let ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(expected, [4, 3, 5, 6]);

        let diff = (output - ans).asum();
        assert!(diff < 1e-6);
    }

    #[test]
    fn inplace_transpose_4d_cpu() {
        inplace_transpose_4d::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn inplace_transpose_4d_cuda() {
        inplace_transpose_4d::<crate::device::nvidia::Nvidia>();
    }

    /// Device-generic body of the 2-D axis swap test ([2, 3] -> [3, 2]).
    fn swap_axis<D: Device>() {
        let source: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let expected: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 4., 2., 5., 3., 6.], [3, 2]);

        let swapped = source.transpose_swap_index_new_matrix(0, 1);

        let diff = (swapped - expected).asum();
        assert!(diff < 1e-6);
    }

    #[test]
    fn swap_axis_cpu() {
        swap_axis::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn swap_axis_nvidia() {
        swap_axis::<crate::device::nvidia::Nvidia>();
    }

    /// Device-generic body of the 3-D axis swap test ([2, 3, 4] -> [3, 2, 4]).
    ///
    /// This is not a `#[test]` itself; the per-device tests below instantiate it.
    fn swap_axis_3d<D: Device>() {
        let source: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
                19., 20., 21., 22., 23., 24.,
            ],
            [2, 3, 4],
        );
        let expected: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(
            vec![
                1., 2., 3., 4., 13., 14., 15., 16., 5., 6., 7., 8., 17., 18., 19., 20., 9., 10.,
                11., 12., 21., 22., 23., 24.,
            ],
            [3, 2, 4],
        );

        let swapped = source.transpose_swap_index_new_matrix(0, 1);

        let diff = (swapped - expected).asum();
        assert!(diff < 1e-6);
    }

    #[test]
    fn swap_axis_3d_cpu() {
        swap_axis_3d::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn swap_axis_3d_cuda() {
        swap_axis_3d::<crate::device::nvidia::Nvidia>();
    }
}