1use crate::{
2 device::Device,
3 dim::{cal_offset, DimDyn, DimTrait},
4 matrix::{Matrix, Owned, Ref},
5 num::Num,
6 shape_stride::ShapeStride,
7};
8
9struct MapAxis<'a, T, F, D>
10where
11 T: Num,
12 F: FnMut(Matrix<Ref<&'a mut T>, DimDyn, D>),
13 D: Device,
14{
15 matrix: Matrix<Ref<&'a mut T>, DimDyn, D>,
16 axis: usize,
17 fn_map: F,
18}
19
20impl<'a, T, F, D> MapAxis<'a, T, F, D>
21where
22 T: Num,
23 F: FnMut(Matrix<Ref<&'a mut T>, DimDyn, D>),
24 D: Device,
25{
26 fn new(matrix: Matrix<Ref<&'a mut T>, DimDyn, D>, axis: usize, fn_map: F) -> Self {
27 Self {
28 matrix,
29 axis,
30 fn_map,
31 }
32 }
33
34 fn target_shape_stride(&self) -> ShapeStride<DimDyn> {
35 let sh = self.target_shape();
36 let st = self.target_stride();
37 ShapeStride::new(DimDyn::from([sh]), DimDyn::from([st]))
38 }
39
40 fn target_stride(&self) -> usize {
41 self.matrix.stride()[self.axis]
42 }
43
44 fn target_shape(&self) -> usize {
45 self.matrix.shape()[self.axis]
46 }
47
48 fn get_index(&self) -> Vec<DimDyn> {
50 let mut candidates = Vec::with_capacity(self.matrix.shape().len() - 1);
52 for (i, s) in self.matrix.shape().into_iter().enumerate() {
53 if i != self.axis {
54 candidates.push(s);
55 }
56 }
57
58 let combinations = generate_combinations(&candidates);
59 combinations
60 .into_iter()
61 .map(|c| {
62 let mut c = c;
63 c.insert(self.axis, 0);
64 DimDyn::from(c.as_slice())
65 })
66 .collect()
67 }
68
69 fn get_offsets(&self) -> Vec<usize> {
70 self.get_index()
71 .into_iter()
72 .map(|itm| cal_offset(self.matrix.stride(), itm))
73 .collect()
74 }
75
76 fn apply(&mut self) {
77 let shapt_stride = self.target_shape_stride();
78 for offset in self.get_offsets() {
79 let view = self.matrix.offset_ptr_mut(offset);
81 let matrix = Matrix::new(view, shapt_stride.shape(), shapt_stride.stride());
82 (self.fn_map)(matrix);
83 }
84 }
85}
86
87fn generate_combinations(nums: &[usize]) -> Vec<Vec<usize>> {
88 fn recurse(
89 index: usize,
90 current: &mut Vec<usize>,
91 nums: &[usize],
92 result: &mut Vec<Vec<usize>>,
93 ) {
94 if index == nums.len() {
95 result.push(current.clone());
96 return;
97 }
98
99 for i in 0..nums[index] {
100 current.push(i);
101 recurse(index + 1, current, nums, result);
102 current.pop();
103 }
104 }
105
106 let mut result = Vec::new();
107 recurse(0, &mut Vec::new(), nums, &mut result);
108 result
109}
110
111pub trait MatrixIter<T: Num, D: Device> {
112 fn map_axis<F>(&self, axis: usize, fn_map: F) -> Matrix<Owned<T>, DimDyn, D>
113 where
114 F: FnMut(Matrix<Ref<&mut T>, DimDyn, D>);
115 fn map_axis_mut<F>(self, axis: usize, fn_map: F)
116 where
117 F: FnMut(Matrix<Ref<&mut T>, DimDyn, D>);
118}
119
120impl<T: Num, D: Device> MatrixIter<T, D> for Matrix<Ref<&mut T>, DimDyn, D> {
121 fn map_axis<F>(&self, axis: usize, fn_map: F) -> Matrix<Owned<T>, DimDyn, D>
122 where
123 F: FnMut(Matrix<Ref<&mut T>, DimDyn, D>),
124 {
125 let mut ans = Matrix::<_, DimDyn, D>::zeros(self.shape());
126 ans.to_ref_mut().copy_from(self);
127 ans.to_ref_mut().map_axis_mut(axis, fn_map);
128 ans
129 }
130
131 fn map_axis_mut<F>(self, axis: usize, fn_map: F)
132 where
133 F: FnMut(Matrix<Ref<&mut T>, DimDyn, D>),
134 {
135 assert!(self.shape().len() > 1, "shape.len() <= 1");
139 let mut map_axis = MapAxis::new(self, axis, fn_map);
140 map_axis.apply();
141 }
142}
143
144#[expect(clippy::float_cmp)]
145#[cfg(test)]
146mod map_axis {
147 use crate::{
148 device::Device,
149 dim::DimDyn,
150 matrix::{Matrix, Owned},
151 matrix_iter::MatrixIter,
152 };
153
154 fn test_2d_0<D: Device>() {
155 let mut a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
156 a.to_ref_mut().map_axis_mut(0, |m| {
157 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 1.], [2]);
158 m.copy_from(&ans);
159 });
160
161 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 2., 2., 1., 1., 1.], [2, 3]);
162 let diff = ans - a;
163 let diff = diff.asum();
164 assert_eq!(diff, 0.);
165 }
166 #[test]
167 fn test_2d_0_cpu() {
168 test_2d_0::<crate::device::cpu::Cpu>();
169 }
170 #[cfg(feature = "nvidia")]
171 #[test]
172 fn test_2d_0_cuda() {
173 test_2d_0::<crate::device::nvidia::Nvidia>();
174 }
175
176 fn test_2d_1<D: Device>() {
177 let mut a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
178 a.to_ref_mut().map_axis_mut(1, |m| {
179 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3., 2., 1.], [3]);
180 m.copy_from(&ans);
181 });
182
183 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3., 2., 1., 3., 2., 1.], [2, 3]);
184 let diff = ans - a;
185 let diff = diff.asum();
186 assert_eq!(diff, 0.);
187 }
188 #[test]
189 fn test_2d_1_cpu() {
190 test_2d_1::<crate::device::cpu::Cpu>();
191 }
192 #[cfg(feature = "nvidia")]
193 #[test]
194 fn test_2d_1_cuda() {
195 test_2d_1::<crate::device::nvidia::Nvidia>();
196 }
197
198 fn test_3d_0<D: Device>() {
199 let mut a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
200 vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
201 [2, 2, 3],
202 );
203
204 a.to_ref_mut().map_axis_mut(0, |m| {
205 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 1.], [2]);
206 m.copy_from(&ans);
207 });
208
209 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
210 vec![2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1.],
211 [2, 2, 3],
212 );
213
214 let diff = ans - a;
215 let diff = diff.asum();
216 assert_eq!(diff, 0.);
217 }
218 #[test]
219 fn test_3d_0_cpu() {
220 test_3d_0::<crate::device::cpu::Cpu>();
221 }
222 #[cfg(feature = "nvidia")]
223 #[test]
224 fn test_3d_0_cuda() {
225 test_3d_0::<crate::device::nvidia::Nvidia>();
226 }
227
228 fn test_3d_1<D: Device>() {
229 let mut a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
230 vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
231 [2, 2, 3],
232 );
233
234 a.to_ref_mut().map_axis_mut(1, |m| {
235 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 1.], [2]);
236 m.copy_from(&ans);
237 });
238
239 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
240 vec![2., 2., 2., 1., 1., 1., 2., 2., 2., 1., 1., 1.],
241 [2, 2, 3],
242 );
243
244 let diff = ans - a;
245 let diff = diff.asum();
246 assert_eq!(diff, 0.);
247 }
248 #[test]
249 fn test_3d_1_cpu() {
250 test_3d_1::<crate::device::cpu::Cpu>();
251 }
252 #[cfg(feature = "nvidia")]
253 #[test]
254 fn test_3d_1_cuda() {
255 test_3d_1::<crate::device::nvidia::Nvidia>();
256 }
257
258 fn test_3d_2<D: Device>() {
259 let mut a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
260 vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
261 [2, 2, 3],
262 );
263
264 a.to_ref_mut().map_axis_mut(2, |m| {
265 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3., 2., 1.], [3]);
266 m.copy_from(&ans);
267 });
268
269 let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
270 vec![3., 2., 1., 3., 2., 1., 3., 2., 1., 3., 2., 1.],
271 [2, 2, 3],
272 );
273
274 let diff = ans - a;
275 let diff = diff.asum();
276 assert_eq!(diff, 0.);
277 }
278 #[test]
279 fn test_3d_2_cpu() {
280 test_3d_2::<crate::device::cpu::Cpu>();
281 }
282 #[cfg(feature = "nvidia")]
283 #[test]
284 fn test_3d_2_cuda() {
285 test_3d_2::<crate::device::nvidia::Nvidia>();
286 }
287}