zenu_matrix/operation/
reshape.rs

1use crate::{
2    device::Device,
3    dim::{default_stride, DimDyn, DimTrait},
4    matrix::{Matrix, Owned, Ref, Repr},
5    num::Num,
6    shape_stride::ShapeStride,
7};
8
9impl<T: Num, R: Repr<Item = T>, S: DimTrait, D: Device> Matrix<R, S, D> {
10    #[expect(clippy::missing_panics_doc)]
11    pub fn reshape<I: Into<DimDyn>>(&self, new_shape: I) -> Matrix<Ref<&T>, DimDyn, D> {
12        let new_shape = new_shape.into();
13        assert_eq!(
14            self.shape().num_elm(),
15            new_shape.num_elm(),
16            "Number of elements must be the same"
17        );
18        assert!(
19            self.shape_stride().is_default_stride(),
20            r#"""
21`reshape` method is not alloc new memory.
22So, This matrix is not default stride, it is not allowed to use `reshape` method.
23Use `reshape_new_matrix` method instead.
24            """#
25        );
26        let new_stride = default_stride(new_shape);
27        let mut result = self.to_ref().into_dyn_dim();
28        result.update_shape_stride(ShapeStride::new(new_shape, new_stride));
29        result
30    }
31}
32
33impl<T: Num, R: Repr<Item = T>, D: Device> Matrix<R, DimDyn, D> {
34    #[expect(clippy::missing_panics_doc)]
35    pub fn reshape_new_matrix<I: Into<DimDyn>>(&self, new_shape: I) -> Matrix<Owned<T>, DimDyn, D> {
36        let new_shape = new_shape.into();
37        assert_eq!(
38            self.shape().num_elm(),
39            new_shape.num_elm(),
40            "Number of elements must be the same"
41        );
42        let new_stride = default_stride(new_shape);
43
44        let mut default_stride_matrix = self.to_ref().to_default_stride();
45        default_stride_matrix.update_shape_stride(ShapeStride::new(new_shape, new_stride));
46        default_stride_matrix
47    }
48}
49impl<T: Num, S: DimTrait, D: Device> Matrix<Owned<T>, S, D> {
50    #[expect(clippy::missing_panics_doc)]
51    pub fn reshape_mut<I: Into<DimDyn>>(&mut self, new_shape: I) -> Matrix<Ref<&mut T>, DimDyn, D> {
52        let new_shape = new_shape.into();
53        assert_eq!(
54            self.shape().num_elm(),
55            new_shape.num_elm(),
56            "Number of elements must be the same"
57        );
58        assert!(
59            self.shape_stride().is_default_stride(),
60            r#"""
61`reshape` method is not alloc new memory.
62So, This matrix is not default stride, it is not allowed to use `reshape` method.
63Use `reshape_new_matrix` method instead.
64            """#
65        );
66        let new_stride = default_stride(new_shape);
67        let mut result = self.to_ref_mut().into_dyn_dim();
68        result.update_shape_stride(ShapeStride::new(new_shape, new_stride));
69        result
70    }
71}
72
73impl<T: Num, S: DimTrait, D: Device> Matrix<Owned<T>, S, D> {
74    #[expect(clippy::missing_panics_doc)]
75    pub fn reshape_no_alloc_owned<I: Into<DimDyn>>(
76        self,
77        new_shape: I,
78    ) -> Matrix<Owned<T>, DimDyn, D> {
79        let new_shape = new_shape.into();
80        assert_eq!(
81            self.shape().num_elm(),
82            new_shape.num_elm(),
83            "Number of elements must be the same"
84        );
85        assert!(
86            self.shape_stride().is_default_stride(),
87            r#"""
88`reshape` method is not alloc new memory.
89So, This matrix is not default stride, it is not allowed to use `reshape` method.
90Use `reshape_new_matrix` method instead.
91            """#
92        );
93        let mut s = self.into_dyn_dim();
94        let new_shape_stride = ShapeStride::new(new_shape, default_stride(new_shape));
95        s.update_shape_stride(new_shape_stride);
96        s
97    }
98}
99
#[cfg(test)]
mod reshape {
    use crate::{
        device::Device,
        dim::{DimDyn, DimTrait},
        matrix::{Matrix, Owned},
    };

    /// A contiguous 3-D matrix flattened to 1-D keeps its element order.
    fn reshape_3d_1d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
            ],
            [2, 3, 3],
        );
        let b = a.reshape([18]);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
            ],
            [18],
        );
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_3d_1d_cpu() {
        reshape_3d_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_3d_1d_gpu() {
        reshape_3d_1d::<crate::device::nvidia::Nvidia>();
    }

    /// A 1-D matrix can be viewed as 3-D with the same element order.
    fn reshape_1d_3d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [6]);
        let b = a.reshape([2, 3, 1]);
        let ans =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3, 1]);
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_1d_3d_cpu() {
        reshape_1d_3d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_1d_3d_gpu() {
        reshape_1d_3d::<crate::device::nvidia::Nvidia>();
    }

    /// A transposed (non-default-stride) view must be copied before flattening, so the
    /// result follows the transposed element order.
    fn reshape_new_matrix_3d_1d<D: Device>() {
        let mut a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
            ],
            [2, 3, 3],
        );
        let mut a_ref_mut = a.to_ref_mut();
        a_ref_mut.transpose_by_index(&[2, 1, 0]);
        let b = a_ref_mut.reshape_new_matrix([18]);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![
                1., 10., 4., 13., 7., 16., 2., 11., 5., 14., 8., 17., 3., 12., 6., 15., 9., 18.,
            ],
            [18],
        );
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_new_matrix_3d_1d_cpu() {
        reshape_new_matrix_3d_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_new_matrix_3d_1d_gpu() {
        reshape_new_matrix_3d_1d::<crate::device::nvidia::Nvidia>();
    }

    /// Consuming reshape keeps the buffer and element order.
    fn reshape_no_alloc_owned_2d_1d<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        let b = a.reshape_no_alloc_owned([6]);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [6]);
        assert_eq!(b.shape().slice(), ans.shape().slice());
        assert!((b.to_ref() - ans).to_ref().asum() < 1e-6);
    }
    #[test]
    fn reshape_no_alloc_owned_2d_1d_cpu() {
        reshape_no_alloc_owned_2d_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn reshape_no_alloc_owned_2d_1d_gpu() {
        reshape_no_alloc_owned_2d_1d::<crate::device::nvidia::Nvidia>();
    }

    #[test]
    #[should_panic(expected = "Number of elements must be the same")]
    fn reshape_mismatched_num_elm_panics() {
        let a = Matrix::<Owned<f32>, DimDyn, crate::device::cpu::Cpu>::from_vec(
            vec![1., 2., 3., 4., 5., 6.],
            [6],
        );
        let _ = a.reshape([4]);
    }

    #[test]
    #[should_panic]
    fn reshape_non_default_stride_panics() {
        let mut a = Matrix::<Owned<f32>, DimDyn, crate::device::cpu::Cpu>::from_vec(
            vec![1., 2., 3., 4., 5., 6.],
            [2, 3],
        );
        let mut a_ref_mut = a.to_ref_mut();
        a_ref_mut.transpose_by_index(&[1, 0]);
        let _ = a_ref_mut.reshape([6]);
    }
}