zenu_matrix/index/
index_impl.rs

1use crate::dim::DimTrait;
2use crate::index::{IndexAxisTrait, ShapeStride};
3
4macro_rules! impl_index_axis {
5    ($impl_name:ident, $target_dim:expr) => {
6        #[derive(Copy, Clone, Debug, PartialEq)]
7        pub struct $impl_name(pub usize);
8
9        impl $impl_name {
10            #[must_use]
11            pub fn new(index: usize) -> Self {
12                $impl_name(index)
13            }
14
15            #[must_use]
16            pub fn index(&self) -> usize {
17                self.0
18            }
19
20            #[must_use]
21            pub fn target_dim(&self) -> usize {
22                $target_dim
23            }
24
25            pub fn get_shape_stride<Din: DimTrait, Dout: DimTrait>(
26                &self,
27                shape: &Din,
28                stride: &Din,
29            ) -> ShapeStride<Dout> {
30                let mut shape_v = Vec::new();
31                let mut stride_v = Vec::new();
32                for i in 0..shape.len() {
33                    if i == $target_dim {
34                        continue;
35                    }
36                    shape_v.push(shape[i]);
37                    stride_v.push(stride[i]);
38                }
39
40                let new_shape = Dout::from(&shape_v as &[usize]);
41                let new_stride = Dout::from(&stride_v as &[usize]);
42                ShapeStride::new(new_shape, new_stride)
43            }
44
45            pub fn get_offset<D: DimTrait>(&self, stride: D) -> usize {
46                stride[$target_dim] * self.0
47            }
48        }
49    };
50}
51impl_index_axis!(Index0D, 0);
52impl_index_axis!(Index1D, 1);
53impl_index_axis!(Index2D, 2);
54impl_index_axis!(Index3D, 3);
55
56macro_rules! impl_index_axis_trait {
57    ($impl_trait:ident) => {
58        impl IndexAxisTrait for $impl_trait {
59            fn get_shape_stride<Din: DimTrait, Dout: DimTrait>(
60                &self,
61                shape: Din,
62                stride: Din,
63            ) -> ShapeStride<Dout> {
64                self.get_shape_stride::<Din, Dout>(&shape, &stride)
65            }
66
67            fn offset<Din: DimTrait>(&self, stride: Din) -> usize {
68                self.get_offset::<Din>(stride.clone())
69            }
70        }
71    };
72}
73impl_index_axis_trait!(Index0D);
74impl_index_axis_trait!(Index1D);
75impl_index_axis_trait!(Index2D);
76impl_index_axis_trait!(Index3D);
77
#[cfg(test)]
mod index_xd {
    //! Tests for the single-axis index types: accessors, element offsets on
    //! every supported axis, shape/stride reduction, and the
    //! `IndexAxisTrait` forwarding impls.

    use super::{Index0D, Index1D, Index2D, Index3D};
    use crate::dim::{Dim1, Dim2, Dim3, Dim4};
    use crate::index::IndexAxisTrait;

    #[test]
    fn accessors_report_index_and_axis() {
        let index = Index2D::new(7);
        assert_eq!(index.index(), 7);
        assert_eq!(index.target_dim(), 2);
        assert_eq!(Index0D::new(0).target_dim(), 0);
        assert_eq!(Index1D::new(0).target_dim(), 1);
        assert_eq!(Index3D::new(0).target_dim(), 3);
    }

    #[test]
    fn offset_1d() {
        let stride = Dim1::new([1]);
        let index = Index0D::new(1);
        let offset = index.get_offset(stride);
        assert_eq!(offset, 1);
    }

    #[test]
    fn offset_2d() {
        let stride = Dim2::new([4, 1]);
        let index = Index0D::new(2);
        let offset = index.get_offset(stride);
        assert_eq!(offset, 8);
    }

    #[test]
    fn offset_3d() {
        let stride = Dim3::new([20, 5, 1]);
        let index = Index0D::new(2);
        let offset = index.get_offset(stride);
        assert_eq!(offset, 40);
    }

    // The offset tests above only cover axis 0; these exercise the other axes.

    #[test]
    fn offset_2d_axis_1() {
        let stride = Dim2::new([4, 1]);
        assert_eq!(Index1D::new(3).get_offset(stride), 3);
    }

    #[test]
    fn offset_3d_axis_1() {
        let stride = Dim3::new([20, 5, 1]);
        assert_eq!(Index1D::new(2).get_offset(stride), 10);
    }

    #[test]
    fn offset_3d_axis_2() {
        let stride = Dim3::new([20, 5, 1]);
        assert_eq!(Index2D::new(4).get_offset(stride), 4);
    }

    #[test]
    fn offset_4d_axis_2() {
        let stride = Dim4::new([120, 30, 6, 1]);
        assert_eq!(Index2D::new(3).get_offset(stride), 18);
    }

    #[test]
    fn offset_4d_axis_3() {
        let stride = Dim4::new([120, 30, 6, 1]);
        assert_eq!(Index3D::new(5).get_offset(stride), 5);
    }

    #[test]
    fn shape_stride_2d_index_0() {
        let shape = Dim2::new([3, 4]);
        let stride = Dim2::new([4, 1]);

        let index = Index0D::new(1);

        let shape_stride = index.get_shape_stride::<Dim2, Dim1>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim1::new([4]));
        assert_eq!(shape_stride.stride(), Dim1::new([1]));
    }

    #[test]
    fn shape_stride_2d_index_1() {
        let shape = Dim2::new([3, 4]);
        let stride = Dim2::new([4, 1]);

        let index = Index1D::new(1);

        let shape_stride = index.get_shape_stride::<Dim2, Dim1>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim1::new([3]));
        assert_eq!(shape_stride.stride(), Dim1::new([4]));
    }

    #[test]
    fn shape_stride_3d_index_0() {
        let shape = Dim3::new([3, 4, 5]);
        let stride = Dim3::new([20, 5, 1]);

        let index = Index0D::new(1);

        let shape_stride = index.get_shape_stride::<Dim3, Dim2>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim2::new([4, 5]));
        assert_eq!(shape_stride.stride(), Dim2::new([5, 1]));
    }

    #[test]
    fn shape_stride_3d_index_1() {
        let shape = Dim3::new([3, 4, 5]);
        let stride = Dim3::new([20, 5, 1]);

        let index = Index1D::new(1);

        let shape_stride = index.get_shape_stride::<Dim3, Dim2>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim2::new([3, 5]));
        assert_eq!(shape_stride.stride(), Dim2::new([20, 1]));
    }

    #[test]
    fn shape_stride_3d_index_2() {
        let shape = Dim3::new([3, 4, 5]);
        let stride = Dim3::new([20, 5, 1]);

        let index = Index2D::new(1);

        let shape_stride = index.get_shape_stride::<Dim3, Dim2>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim2::new([3, 4]));
        assert_eq!(shape_stride.stride(), Dim2::new([20, 5]));
    }

    #[test]
    fn shape_stride_4d_index_0() {
        let shape = Dim4::new([3, 4, 5, 6]);
        let stride = Dim4::new([120, 30, 6, 1]);

        let index = Index0D::new(1);

        let shape_stride = index.get_shape_stride::<Dim4, Dim3>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim3::new([4, 5, 6]));
        assert_eq!(shape_stride.stride(), Dim3::new([30, 6, 1]));
    }

    #[test]
    fn shape_stride_4d_index_1() {
        let shape = Dim4::new([3, 4, 5, 6]);
        let stride = Dim4::new([120, 30, 6, 1]);

        let index = Index1D::new(1);

        let shape_stride = index.get_shape_stride::<Dim4, Dim3>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim3::new([3, 5, 6]));
        assert_eq!(shape_stride.stride(), Dim3::new([120, 6, 1]));
    }

    #[test]
    fn shape_stride_4d_index_2() {
        let shape = Dim4::new([3, 4, 5, 6]);
        let stride = Dim4::new([120, 30, 6, 1]);

        let index = Index2D::new(1);

        let shape_stride = index.get_shape_stride::<Dim4, Dim3>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim3::new([3, 4, 6]));
        assert_eq!(shape_stride.stride(), Dim3::new([120, 30, 1]));
    }

    #[test]
    fn shape_stride_4d_index_3() {
        let shape = Dim4::new([3, 4, 5, 6]);
        let stride = Dim4::new([120, 30, 6, 1]);

        let index = Index3D::new(1);

        let shape_stride = index.get_shape_stride::<Dim4, Dim3>(&shape, &stride);

        assert_eq!(shape_stride.shape(), Dim3::new([3, 4, 5]));
        assert_eq!(shape_stride.stride(), Dim3::new([120, 30, 6]));
    }

    // The `IndexAxisTrait` impls must agree with the inherent methods they
    // forward to.

    #[test]
    fn trait_offset_matches_inherent() {
        let index = Index1D::new(2);
        let via_trait = IndexAxisTrait::offset(&index, Dim3::new([20, 5, 1]));
        let via_inherent = index.get_offset(Dim3::new([20, 5, 1]));
        assert_eq!(via_trait, 10);
        assert_eq!(via_trait, via_inherent);
    }

    #[test]
    fn trait_shape_stride_matches_inherent() {
        let index = Index1D::new(1);
        let shape_stride = IndexAxisTrait::get_shape_stride::<Dim3, Dim2>(
            &index,
            Dim3::new([3, 4, 5]),
            Dim3::new([20, 5, 1]),
        );
        assert_eq!(shape_stride.shape(), Dim2::new([3, 5]));
        assert_eq!(shape_stride.stride(), Dim2::new([20, 1]));
    }
}