1use crate::{
2 device::Device,
3 dim::{default_stride, DimDyn, DimTrait},
4 matrix::{Matrix, Owned, Ref, Repr},
5 num::Num,
6 shape_stride::ShapeStride,
7};
8
9impl<T: Num, R: Repr<Item = T>, S: DimTrait, D: Device> Matrix<R, S, D> {
10 #[expect(clippy::missing_panics_doc)]
11 pub fn reshape<I: Into<DimDyn>>(&self, new_shape: I) -> Matrix<Ref<&T>, DimDyn, D> {
12 let new_shape = new_shape.into();
13 assert_eq!(
14 self.shape().num_elm(),
15 new_shape.num_elm(),
16 "Number of elements must be the same"
17 );
18 assert!(
19 self.shape_stride().is_default_stride(),
20 r#"""
21`reshape` method is not alloc new memory.
22So, This matrix is not default stride, it is not allowed to use `reshape` method.
23Use `reshape_new_matrix` method instead.
24 """#
25 );
26 let new_stride = default_stride(new_shape);
27 let mut result = self.to_ref().into_dyn_dim();
28 result.update_shape_stride(ShapeStride::new(new_shape, new_stride));
29 result
30 }
31}
32
33impl<T: Num, R: Repr<Item = T>, D: Device> Matrix<R, DimDyn, D> {
34 #[expect(clippy::missing_panics_doc)]
35 pub fn reshape_new_matrix<I: Into<DimDyn>>(&self, new_shape: I) -> Matrix<Owned<T>, DimDyn, D> {
36 let new_shape = new_shape.into();
37 assert_eq!(
38 self.shape().num_elm(),
39 new_shape.num_elm(),
40 "Number of elements must be the same"
41 );
42 let new_stride = default_stride(new_shape);
43
44 let mut default_stride_matrix = self.to_ref().to_default_stride();
45 default_stride_matrix.update_shape_stride(ShapeStride::new(new_shape, new_stride));
46 default_stride_matrix
47 }
48}
49impl<T: Num, S: DimTrait, D: Device> Matrix<Owned<T>, S, D> {
50 #[expect(clippy::missing_panics_doc)]
51 pub fn reshape_mut<I: Into<DimDyn>>(&mut self, new_shape: I) -> Matrix<Ref<&mut T>, DimDyn, D> {
52 let new_shape = new_shape.into();
53 assert_eq!(
54 self.shape().num_elm(),
55 new_shape.num_elm(),
56 "Number of elements must be the same"
57 );
58 assert!(
59 self.shape_stride().is_default_stride(),
60 r#"""
61`reshape` method is not alloc new memory.
62So, This matrix is not default stride, it is not allowed to use `reshape` method.
63Use `reshape_new_matrix` method instead.
64 """#
65 );
66 let new_stride = default_stride(new_shape);
67 let mut result = self.to_ref_mut().into_dyn_dim();
68 result.update_shape_stride(ShapeStride::new(new_shape, new_stride));
69 result
70 }
71}
72
73impl<T: Num, S: DimTrait, D: Device> Matrix<Owned<T>, S, D> {
74 #[expect(clippy::missing_panics_doc)]
75 pub fn reshape_no_alloc_owned<I: Into<DimDyn>>(
76 self,
77 new_shape: I,
78 ) -> Matrix<Owned<T>, DimDyn, D> {
79 let new_shape = new_shape.into();
80 assert_eq!(
81 self.shape().num_elm(),
82 new_shape.num_elm(),
83 "Number of elements must be the same"
84 );
85 assert!(
86 self.shape_stride().is_default_stride(),
87 r#"""
88`reshape` method is not alloc new memory.
89So, This matrix is not default stride, it is not allowed to use `reshape` method.
90Use `reshape_new_matrix` method instead.
91 """#
92 );
93 let mut s = self.into_dyn_dim();
94 let new_shape_stride = ShapeStride::new(new_shape, default_stride(new_shape));
95 s.update_shape_stride(new_shape_stride);
96 s
97 }
98}
99
#[cfg(test)]
mod reshape {
    use crate::{
        device::Device,
        dim::{DimDyn, DimTrait},
        matrix::{Matrix, Owned},
    };

    fn reshape_3d_1d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
            ],
            [2, 3, 3],
        );
        let b = a.reshape([18]);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
            ],
            [18],
        );
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_3d_1d_cpu() {
        reshape_3d_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_3d_1d_gpu() {
        reshape_3d_1d::<crate::device::nvidia::Nvidia>();
    }

    fn reshape_1d_3d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [6]);
        let b = a.reshape([2, 3, 1]);
        let ans =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3, 1]);
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_1d_3d_cpu() {
        reshape_1d_3d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_1d_3d_gpu() {
        reshape_1d_3d::<crate::device::nvidia::Nvidia>();
    }

    // `reshape` must refuse a non-default-stride view instead of silently
    // reinterpreting its memory. The expected substring appears in the panic message.
    fn reshape_non_default_stride_panics<D: Device>() {
        let mut a =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let mut a_ref_mut = a.to_ref_mut();
        a_ref_mut.transpose_by_index(&[1, 0]);
        let _b = a_ref_mut.reshape([6]);
    }
    #[test]
    #[should_panic(expected = "default stride")]
    fn reshape_non_default_stride_panics_cpu() {
        reshape_non_default_stride_panics::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    #[should_panic(expected = "default stride")]
    fn reshape_non_default_stride_panics_gpu() {
        reshape_non_default_stride_panics::<crate::device::nvidia::Nvidia>();
    }

    fn reshape_mut_2d_1d<D: Device>() {
        let mut a =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let b = a.reshape_mut([6]);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [6]);
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b.to_ref() - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_mut_2d_1d_cpu() {
        reshape_mut_2d_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_mut_2d_1d_gpu() {
        reshape_mut_2d_1d::<crate::device::nvidia::Nvidia>();
    }

    fn reshape_no_alloc_owned_2d_1d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let b = a.reshape_no_alloc_owned([6]);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [6]);
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b.to_ref() - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_no_alloc_owned_2d_1d_cpu() {
        reshape_no_alloc_owned_2d_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_no_alloc_owned_2d_1d_gpu() {
        reshape_no_alloc_owned_2d_1d::<crate::device::nvidia::Nvidia>();
    }

    fn reshape_new_matrix_3d_1d<D: Device>() {
        let mut a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
            ],
            [2, 3, 3],
        );
        let mut a_ref_mut = a.to_ref_mut();
        a_ref_mut.transpose_by_index(&[2, 1, 0]);
        let b = a_ref_mut.reshape_new_matrix([18]);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                1., 10., 4., 13., 7., 16., 2., 11., 5., 14., 8., 17., 3., 12., 6., 15., 9., 18.,
            ],
            [18],
        );
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_new_matrix_3d_1d_cpu() {
        reshape_new_matrix_3d_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_new_matrix_3d_1d_gpu() {
        reshape_new_matrix_3d_1d::<crate::device::nvidia::Nvidia>();
    }
}