1use crate::{
2 device::Device,
3 dim::{DimDyn, DimTrait},
4 matrix::{Matrix, Owned, Repr},
5};
6
7impl<R: Repr, D: Device> Matrix<R, DimDyn, D> {
8 pub fn transpose(&mut self) {
9 let shape_stride = self.shape_stride();
10 let transposed = shape_stride.transpose();
11 self.update_shape_stride(transposed);
12 }
13
14 pub fn transpose_by_index(&mut self, index: &[usize]) {
15 let shape_stride = self.shape_stride();
16 let transposed = shape_stride.transpose_by_index(index);
17 self.update_shape_stride(transposed);
18 }
19
20 #[must_use]
21 pub fn transpose_by_index_new_matrix(
22 &self,
23 index: &[usize],
24 ) -> Matrix<Owned<R::Item>, DimDyn, D> {
25 let mut ref_mat = self.to_ref();
26 ref_mat.transpose_by_index(index);
27 ref_mat.to_default_stride()
28 }
29
30 #[expect(clippy::missing_panics_doc)]
31 pub fn transpose_swap_index(&mut self, a: usize, b: usize) {
32 assert!(a != b, "Index must be different");
33 if a < b {
34 return self.transpose_swap_index(b, a);
35 }
36 assert!(a < self.shape().len(), "Index out of range");
37 assert!(b < self.shape().len(), "Index out of range");
38
39 let shape_stride = self.shape_stride().swap_index(a, b);
40 self.update_shape_stride(shape_stride);
41 }
42
43 #[expect(clippy::missing_panics_doc)]
44 #[must_use]
45 pub fn transpose_swap_index_new_matrix(
46 &self,
47 a: usize,
48 b: usize,
49 ) -> Matrix<Owned<R::Item>, DimDyn, D> {
50 assert!(a != b, "Index must be different");
51 if a < b {
52 return self.transpose_swap_index_new_matrix(b, a);
53 }
54 let mut ref_mat = self.to_ref();
55 ref_mat.transpose_swap_index(a, b);
56 ref_mat.to_default_stride()
57 }
58}
59
#[expect(clippy::float_cmp)]
#[cfg(test)]
mod transpose {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Transposes a 2x3 matrix built from `1..=6` in place.
    ///
    /// Checks every element of the resulting 3x2 matrix through `index_item`.
    fn transpose_2d<D: Device>() {
        let mut mat: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        mat.transpose();

        // Each entry is (position in the transposed matrix, value expected there).
        let expected = [
            ([0, 0], 1.),
            ([0, 1], 4.),
            ([1, 0], 2.),
            ([1, 1], 5.),
            ([2, 0], 3.),
            ([2, 1], 6.),
        ];
        for (position, value) in expected {
            assert_eq!(mat.index_item(position), value);
        }
    }

    /// Runs the 2-D transpose check on the CPU backend.
    #[test]
    fn transpose_2d_cpu() {
        transpose_2d::<crate::device::cpu::Cpu>();
    }

    /// Runs the 2-D transpose check on the NVIDIA backend.
    #[cfg(feature = "nvidia")]
    #[test]
    fn transpose_2d_cuda() {
        transpose_2d::<crate::device::nvidia::Nvidia>();
    }
}
91
#[expect(clippy::cast_precision_loss)]
#[cfg(test)]
mod transpose_inplace {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Permutes a 3x4x5x6 tensor with `[1, 0, 2, 3]` and compares the result
    /// against the expected 4x3x5x6 layout.
    ///
    /// Every element stores its own coordinates as decimal digits
    /// (`i*1000 + j*100 + k*10 + l`). This makes the expected buffer easy to
    /// derive: visit the same coordinates with axes 0 and 1 exchanged.
    fn inplace_transpose_4d<D: Device>() {
        let encode = |i: i32, j: i32, k: i32, l: i32| (i * 1000 + j * 100 + k * 10 + l) as f32;

        let mut source = Vec::with_capacity(3 * 4 * 5 * 6);
        for i in 0..3 {
            for j in 0..4 {
                for k in 0..5 {
                    for l in 0..6 {
                        source.push(encode(i, j, k, l));
                    }
                }
            }
        }

        // After swapping the first two axes, the row-major walk visits `j`
        // before `i`.
        let mut permuted = Vec::with_capacity(4 * 3 * 5 * 6);
        for j in 0..4 {
            for i in 0..3 {
                for k in 0..5 {
                    for l in 0..6 {
                        permuted.push(encode(i, j, k, l));
                    }
                }
            }
        }

        let source: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(source, [3, 4, 5, 6]);
        let expected: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(permuted, [4, 3, 5, 6]);

        let actual = source.transpose_by_index_new_matrix(&[1, 0, 2, 3]);
        assert!((actual - expected).asum() < 1e-6);
    }

    /// Runs the 4-D permutation check on the CPU backend.
    #[test]
    fn inplace_transpose_4d_cpu() {
        inplace_transpose_4d::<crate::device::cpu::Cpu>();
    }

    /// Runs the 4-D permutation check on the NVIDIA backend.
    #[cfg(feature = "nvidia")]
    #[test]
    fn inplace_transpose_4d_cuda() {
        inplace_transpose_4d::<crate::device::nvidia::Nvidia>();
    }

    /// Swapping axes 0 and 1 of a 2x3 matrix yields its 3x2 transpose.
    fn swap_axis<D: Device>() {
        let source: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let expected: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1., 4., 2., 5., 3., 6.], [3, 2]);

        let actual = source.transpose_swap_index_new_matrix(0, 1);
        assert!((actual - expected).asum() < 1e-6);
    }

    /// Runs the 2-D axis-swap check on the CPU backend.
    #[test]
    fn swap_axis_cpu() {
        swap_axis::<crate::device::cpu::Cpu>();
    }

    /// Runs the 2-D axis-swap check on the NVIDIA backend.
    #[cfg(feature = "nvidia")]
    #[test]
    fn swap_axis_nvidia() {
        swap_axis::<crate::device::nvidia::Nvidia>();
    }

    /// Swaps the two leading axes of a 2x3x4 tensor holding `1..=24`.
    ///
    /// The two 3x4 slabs become interleaved row by row in the resulting
    /// 3x2x4 tensor.
    fn swap_axis_3d<D: Device>() {
        let source_data = vec![
            1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
            20., 21., 22., 23., 24.,
        ];
        let expected_data = vec![
            1., 2., 3., 4., 13., 14., 15., 16., 5., 6., 7., 8., 17., 18., 19., 20., 9., 10., 11.,
            12., 21., 22., 23., 24.,
        ];
        let source: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(source_data, [2, 3, 4]);
        let expected: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(expected_data, [3, 2, 4]);

        let actual = source.transpose_swap_index_new_matrix(0, 1);
        assert!((actual - expected).asum() < 1e-6);
    }

    /// Runs the 3-D axis-swap check on the CPU backend.
    #[test]
    fn swap_axis_3d_cpu() {
        swap_axis_3d::<crate::device::cpu::Cpu>();
    }

    /// Runs the 3-D axis-swap check on the NVIDIA backend.
    #[cfg(feature = "nvidia")]
    #[test]
    fn swap_axis_3d_cuda() {
        swap_axis_3d::<crate::device::nvidia::Nvidia>();
    }
}