1use crate::dim::DimTrait;
2use crate::index::{IndexAxisTrait, ShapeStride};
3
4macro_rules! impl_index_axis {
5 ($impl_name:ident, $target_dim:expr) => {
6 #[derive(Copy, Clone, Debug, PartialEq)]
7 pub struct $impl_name(pub usize);
8
9 impl $impl_name {
10 #[must_use]
11 pub fn new(index: usize) -> Self {
12 $impl_name(index)
13 }
14
15 #[must_use]
16 pub fn index(&self) -> usize {
17 self.0
18 }
19
20 #[must_use]
21 pub fn target_dim(&self) -> usize {
22 $target_dim
23 }
24
25 pub fn get_shape_stride<Din: DimTrait, Dout: DimTrait>(
26 &self,
27 shape: &Din,
28 stride: &Din,
29 ) -> ShapeStride<Dout> {
30 let mut shape_v = Vec::new();
31 let mut stride_v = Vec::new();
32 for i in 0..shape.len() {
33 if i == $target_dim {
34 continue;
35 }
36 shape_v.push(shape[i]);
37 stride_v.push(stride[i]);
38 }
39
40 let new_shape = Dout::from(&shape_v as &[usize]);
41 let new_stride = Dout::from(&stride_v as &[usize]);
42 ShapeStride::new(new_shape, new_stride)
43 }
44
45 pub fn get_offset<D: DimTrait>(&self, stride: D) -> usize {
46 stride[$target_dim] * self.0
47 }
48 }
49 };
50}
51impl_index_axis!(Index0D, 0);
52impl_index_axis!(Index1D, 1);
53impl_index_axis!(Index2D, 2);
54impl_index_axis!(Index3D, 3);
55
56macro_rules! impl_index_axis_trait {
57 ($impl_trait:ident) => {
58 impl IndexAxisTrait for $impl_trait {
59 fn get_shape_stride<Din: DimTrait, Dout: DimTrait>(
60 &self,
61 shape: Din,
62 stride: Din,
63 ) -> ShapeStride<Dout> {
64 self.get_shape_stride::<Din, Dout>(&shape, &stride)
65 }
66
67 fn offset<Din: DimTrait>(&self, stride: Din) -> usize {
68 self.get_offset::<Din>(stride.clone())
69 }
70 }
71 };
72}
73impl_index_axis_trait!(Index0D);
74impl_index_axis_trait!(Index1D);
75impl_index_axis_trait!(Index2D);
76impl_index_axis_trait!(Index3D);
77
/// Unit tests for the axis-index types: offset computation and the
/// shape/stride pair left after removing the indexed axis.
#[cfg(test)]
mod index_xd {
    use super::{Index0D, Index1D, Index2D, Index3D};
    use crate::dim::{Dim1, Dim2, Dim3, Dim4};

    // --- get_offset: axis-0 stride times the index -------------------------

    #[test]
    fn offset_1d() {
        let strides = Dim1::new([1]);
        let offset = Index0D::new(1).get_offset(strides);
        assert_eq!(offset, 1);
    }

    #[test]
    fn offset_2d() {
        let strides = Dim2::new([4, 1]);
        let offset = Index0D::new(2).get_offset(strides);
        assert_eq!(offset, 8);
    }

    #[test]
    fn offset_3d() {
        let strides = Dim3::new([20, 5, 1]);
        let offset = Index0D::new(2).get_offset(strides);
        assert_eq!(offset, 40);
    }

    // --- get_shape_stride: 2-D input, 1-D output ---------------------------

    #[test]
    fn shape_stride_2d_index_0() {
        let dims = Dim2::new([3, 4]);
        let strides = Dim2::new([4, 1]);

        let reduced = Index0D::new(1).get_shape_stride::<Dim2, Dim1>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim1::new([4]));
        assert_eq!(reduced.stride(), Dim1::new([1]));
    }

    #[test]
    fn shape_stride_2d_index_1() {
        let dims = Dim2::new([3, 4]);
        let strides = Dim2::new([4, 1]);

        let reduced = Index1D::new(1).get_shape_stride::<Dim2, Dim1>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim1::new([3]));
        assert_eq!(reduced.stride(), Dim1::new([4]));
    }

    // --- get_shape_stride: 3-D input, 2-D output ---------------------------

    #[test]
    fn shape_stride_3d_index_0() {
        let dims = Dim3::new([3, 4, 5]);
        let strides = Dim3::new([20, 5, 1]);

        let reduced = Index0D::new(1).get_shape_stride::<Dim3, Dim2>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim2::new([4, 5]));
        assert_eq!(reduced.stride(), Dim2::new([5, 1]));
    }

    #[test]
    fn shape_stride_3d_index_1() {
        let dims = Dim3::new([3, 4, 5]);
        let strides = Dim3::new([20, 5, 1]);

        let reduced = Index1D::new(1).get_shape_stride::<Dim3, Dim2>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim2::new([3, 5]));
        assert_eq!(reduced.stride(), Dim2::new([20, 1]));
    }

    #[test]
    fn shape_stride_3d_index_2() {
        let dims = Dim3::new([3, 4, 5]);
        let strides = Dim3::new([20, 5, 1]);

        let reduced = Index2D::new(1).get_shape_stride::<Dim3, Dim2>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim2::new([3, 4]));
        assert_eq!(reduced.stride(), Dim2::new([20, 5]));
    }

    // --- get_shape_stride: 4-D input, 3-D output ---------------------------

    #[test]
    fn shape_stride_4d_index_0() {
        let dims = Dim4::new([3, 4, 5, 6]);
        let strides = Dim4::new([120, 30, 6, 1]);

        let reduced = Index0D::new(1).get_shape_stride::<Dim4, Dim3>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim3::new([4, 5, 6]));
        assert_eq!(reduced.stride(), Dim3::new([30, 6, 1]));
    }

    #[test]
    fn shape_stride_4d_index_1() {
        let dims = Dim4::new([3, 4, 5, 6]);
        let strides = Dim4::new([120, 30, 6, 1]);

        let reduced = Index1D::new(1).get_shape_stride::<Dim4, Dim3>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim3::new([3, 5, 6]));
        assert_eq!(reduced.stride(), Dim3::new([120, 6, 1]));
    }

    #[test]
    fn shape_stride_4d_index_2() {
        let dims = Dim4::new([3, 4, 5, 6]);
        let strides = Dim4::new([120, 30, 6, 1]);

        let reduced = Index2D::new(1).get_shape_stride::<Dim4, Dim3>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim3::new([3, 4, 6]));
        assert_eq!(reduced.stride(), Dim3::new([120, 30, 1]));
    }

    #[test]
    fn shape_stride_4d_index_3() {
        let dims = Dim4::new([3, 4, 5, 6]);
        let strides = Dim4::new([120, 30, 6, 1]);

        let reduced = Index3D::new(1).get_shape_stride::<Dim4, Dim3>(&dims, &strides);

        assert_eq!(reduced.shape(), Dim3::new([3, 4, 5]));
        assert_eq!(reduced.stride(), Dim3::new([120, 30, 6]));
    }
}