zenu_matrix/
matrix_iter.rs

1use crate::{
2    device::Device,
3    dim::{cal_offset, DimDyn, DimTrait},
4    matrix::{Matrix, Owned, Ref},
5    num::Num,
6    shape_stride::ShapeStride,
7};
8
9struct MapAxis<'a, T, F, D>
10where
11    T: Num,
12    F: FnMut(Matrix<Ref<&'a mut T>, DimDyn, D>),
13    D: Device,
14{
15    matrix: Matrix<Ref<&'a mut T>, DimDyn, D>,
16    axis: usize,
17    fn_map: F,
18}
19
20impl<'a, T, F, D> MapAxis<'a, T, F, D>
21where
22    T: Num,
23    F: FnMut(Matrix<Ref<&'a mut T>, DimDyn, D>),
24    D: Device,
25{
26    fn new(matrix: Matrix<Ref<&'a mut T>, DimDyn, D>, axis: usize, fn_map: F) -> Self {
27        Self {
28            matrix,
29            axis,
30            fn_map,
31        }
32    }
33
34    fn target_shape_stride(&self) -> ShapeStride<DimDyn> {
35        let sh = self.target_shape();
36        let st = self.target_stride();
37        ShapeStride::new(DimDyn::from([sh]), DimDyn::from([st]))
38    }
39
40    fn target_stride(&self) -> usize {
41        self.matrix.stride()[self.axis]
42    }
43
44    fn target_shape(&self) -> usize {
45        self.matrix.shape()[self.axis]
46    }
47
48    /// `fn_map`を適応する際に切り出す`Matrix`の一つ目の要素の`Index`の`Vec`を返す
49    fn get_index(&self) -> Vec<DimDyn> {
50        // axisのIndexを除いたIndexのVecを返す
51        let mut candidates = Vec::with_capacity(self.matrix.shape().len() - 1);
52        for (i, s) in self.matrix.shape().into_iter().enumerate() {
53            if i != self.axis {
54                candidates.push(s);
55            }
56        }
57
58        let combinations = generate_combinations(&candidates);
59        combinations
60            .into_iter()
61            .map(|c| {
62                let mut c = c;
63                c.insert(self.axis, 0);
64                DimDyn::from(c.as_slice())
65            })
66            .collect()
67    }
68
69    fn get_offsets(&self) -> Vec<usize> {
70        self.get_index()
71            .into_iter()
72            .map(|itm| cal_offset(self.matrix.stride(), itm))
73            .collect()
74    }
75
76    fn apply(&mut self) {
77        let shapt_stride = self.target_shape_stride();
78        for offset in self.get_offsets() {
79            // let m = self.matrix;
80            let view = self.matrix.offset_ptr_mut(offset);
81            let matrix = Matrix::new(view, shapt_stride.shape(), shapt_stride.stride());
82            (self.fn_map)(matrix);
83        }
84    }
85}
86
/// Enumerates every index tuple `(i_0, .., i_k)` with `i_j < nums[j]`.
///
/// Tuples are produced in lexicographic order, so the last position varies
/// fastest. An empty `nums` yields a single empty tuple; any zero extent
/// yields no tuples at all.
fn generate_combinations(nums: &[usize]) -> Vec<Vec<usize>> {
    // Start from the single empty prefix and extend every prefix by each
    // admissible value of the next position, one position at a time.
    let mut prefixes: Vec<Vec<usize>> = vec![Vec::new()];
    for &extent in nums {
        prefixes = prefixes
            .into_iter()
            .flat_map(|prefix| {
                (0..extent).map(move |i| {
                    let mut extended = prefix.clone();
                    extended.push(i);
                    extended
                })
            })
            .collect();
    }
    prefixes
}
110
111pub trait MatrixIter<T: Num, D: Device> {
112    fn map_axis<F>(&self, axis: usize, fn_map: F) -> Matrix<Owned<T>, DimDyn, D>
113    where
114        F: FnMut(Matrix<Ref<&mut T>, DimDyn, D>);
115    fn map_axis_mut<F>(self, axis: usize, fn_map: F)
116    where
117        F: FnMut(Matrix<Ref<&mut T>, DimDyn, D>);
118}
119
120impl<T: Num, D: Device> MatrixIter<T, D> for Matrix<Ref<&mut T>, DimDyn, D> {
121    fn map_axis<F>(&self, axis: usize, fn_map: F) -> Matrix<Owned<T>, DimDyn, D>
122    where
123        F: FnMut(Matrix<Ref<&mut T>, DimDyn, D>),
124    {
125        let mut ans = Matrix::<_, DimDyn, D>::zeros(self.shape());
126        ans.to_ref_mut().copy_from(self);
127        ans.to_ref_mut().map_axis_mut(axis, fn_map);
128        ans
129    }
130
131    fn map_axis_mut<F>(self, axis: usize, fn_map: F)
132    where
133        F: FnMut(Matrix<Ref<&mut T>, DimDyn, D>),
134    {
135        // if self.shape().len() <= 1 {
136        //     panic!("shape.len() <= 1");
137        // }
138        assert!(self.shape().len() > 1, "shape.len() <= 1");
139        let mut map_axis = MapAxis::new(self, axis, fn_map);
140        map_axis.apply();
141    }
142}
143
#[expect(clippy::float_cmp)]
#[cfg(test)]
mod map_axis {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
        matrix_iter::MatrixIter,
    };

    /// Asserts that `asum()` of `expected - actual` is exactly zero.
    fn assert_no_diff<D: Device>(
        expected: Matrix<Owned<f32>, DimDyn, D>,
        actual: Matrix<Owned<f32>, DimDyn, D>,
    ) {
        let residual = (expected - actual).asum();
        assert_eq!(residual, 0.);
    }

    /// 2x3 matrix, axis 0: every view along the axis is overwritten with `[2, 1]`.
    fn test_2d_0<D: Device>() {
        let mut target =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        target.to_ref_mut().map_axis_mut(0, |lane| {
            let fill = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 1.], [2]);
            lane.copy_from(&fill);
        });

        let expected =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 2., 2., 1., 1., 1.], [2, 3]);
        assert_no_diff(expected, target);
    }
    #[test]
    fn test_2d_0_cpu() {
        test_2d_0::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_2d_0_cuda() {
        test_2d_0::<crate::device::nvidia::Nvidia>();
    }

    /// 2x3 matrix, axis 1: every view along the axis is overwritten with `[3, 2, 1]`.
    fn test_2d_1<D: Device>() {
        let mut target =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1., 2., 3., 4., 5., 6.], [2, 3]);
        target.to_ref_mut().map_axis_mut(1, |lane| {
            let fill = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3., 2., 1.], [3]);
            lane.copy_from(&fill);
        });

        let expected =
            Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3., 2., 1., 3., 2., 1.], [2, 3]);
        assert_no_diff(expected, target);
    }
    #[test]
    fn test_2d_1_cpu() {
        test_2d_1::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_2d_1_cuda() {
        test_2d_1::<crate::device::nvidia::Nvidia>();
    }

    /// 2x2x3 matrix, axis 0: every view along the axis is overwritten with `[2, 1]`.
    fn test_3d_0<D: Device>() {
        let mut target = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
            [2, 2, 3],
        );

        target.to_ref_mut().map_axis_mut(0, |lane| {
            let fill = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 1.], [2]);
            lane.copy_from(&fill);
        });

        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1.],
            [2, 2, 3],
        );

        assert_no_diff(expected, target);
    }
    #[test]
    fn test_3d_0_cpu() {
        test_3d_0::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_3d_0_cuda() {
        test_3d_0::<crate::device::nvidia::Nvidia>();
    }

    /// 2x2x3 matrix, axis 1: every view along the axis is overwritten with `[2, 1]`.
    fn test_3d_1<D: Device>() {
        let mut target = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
            [2, 2, 3],
        );

        target.to_ref_mut().map_axis_mut(1, |lane| {
            let fill = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![2., 1.], [2]);
            lane.copy_from(&fill);
        });

        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![2., 2., 2., 1., 1., 1., 2., 2., 2., 1., 1., 1.],
            [2, 2, 3],
        );

        assert_no_diff(expected, target);
    }
    #[test]
    fn test_3d_1_cpu() {
        test_3d_1::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_3d_1_cuda() {
        test_3d_1::<crate::device::nvidia::Nvidia>();
    }

    /// 2x2x3 matrix, axis 2: every view along the axis is overwritten with `[3, 2, 1]`.
    fn test_3d_2<D: Device>() {
        let mut target = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
            [2, 2, 3],
        );

        target.to_ref_mut().map_axis_mut(2, |lane| {
            let fill = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3., 2., 1.], [3]);
            lane.copy_from(&fill);
        });

        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![3., 2., 1., 3., 2., 1., 3., 2., 1., 3., 2., 1.],
            [2, 2, 3],
        );

        assert_no_diff(expected, target);
    }
    #[test]
    fn test_3d_2_cpu() {
        test_3d_2::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_3d_2_cuda() {
        test_3d_2::<crate::device::nvidia::Nvidia>();
    }
}