// zenu_matrix/operation/sum.rs

1use crate::{
2    device::Device,
3    dim::{DimDyn, DimTrait, LessDimTrait},
4    index::index_dyn_impl::Index,
5    matrix::{Matrix, Owned, Ref},
6    num::Num,
7};
8
9impl<T: Num, D: Device> Matrix<Ref<&T>, DimDyn, D> {
10    #[expect(clippy::missing_panics_doc)]
11    #[must_use]
12    pub fn sum(&self, axis: usize, keep_dim: bool) -> Matrix<Owned<T>, DimDyn, D> {
13        let shape = self.shape();
14        assert!(axis < shape.len(), "Invalid axis");
15
16        let result_shape = self.shape().remove_axis(axis);
17
18        let mut result = Matrix::zeros(result_shape);
19
20        for i in 0..shape[axis] {
21            let mut result_view_mut = result.to_ref_mut();
22            let s = self.clone();
23            let s = s.index_axis_dyn(Index::new(axis, i));
24            result_view_mut.add_assign(&s);
25        }
26
27        if keep_dim {
28            result.add_axis(axis);
29        }
30        result
31    }
32}
33
34#[expect(clippy::missing_panics_doc, clippy::needless_pass_by_value)]
35pub fn sum_to<T: Num, D: Device>(
36    source: Matrix<Ref<&T>, DimDyn, D>,
37    target: Matrix<Ref<&mut T>, DimDyn, D>,
38) {
39    assert!(
40        source.shape().len() >= target.shape().len(),
41        "source.shape().len() <= target.shape().len()"
42    );
43
44    let diff_len = source.shape().len() - target.shape().len();
45    if source.shape().slice() == target.shape().slice() {
46        let target = target;
47        target.copy_from(&source);
48        return;
49    }
50
51    if diff_len == 0 {
52        let mut diff_axis = Vec::new();
53        for (idx, (s, t)) in source
54            .shape()
55            .slice()
56            .iter()
57            .zip(target.shape().slice().iter())
58            .enumerate()
59        {
60            if *s == *t {
61                continue;
62            }
63            if *t == 1 {
64                diff_axis.push(idx);
65            } else {
66                panic!("hoge");
67            }
68        }
69
70        let mut tmp = source.new_matrix();
71        for axis in diff_axis {
72            let tmp_sum = {
73                let tmp_ref = tmp.to_ref();
74                tmp_ref.sum(axis, true)
75            };
76            tmp = tmp_sum;
77        }
78        target.copy_from(&tmp);
79        return;
80    }
81
82    assert!(
83        source.shape().is_include(target.shape()),
84        "!source.shape().is_include(target.shape())"
85    );
86
87    if diff_len == 1 {
88        let target = target;
89        let ans = source.sum(0, false);
90        target.copy_from(&ans);
91    } else {
92        sum_to(source.sum(0, false).to_ref(), target);
93    }
94}
95
#[cfg(test)]
mod sum {
    #![expect(
        clippy::float_cmp,
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation
    )]
    use crate::{
        device::Device,
        dim::{DimDyn, DimTrait},
        matrix::{Matrix, Owned},
    };

    /// Expected element values after summing the ramp source over axis 1.
    const EXPECTED_AXIS_1: [i32; 40] = [
        60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 240,
        243, 246, 249, 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291, 294,
        297,
    ];

    /// Expected element values after summing the ramp source over axis 2.
    const EXPECTED_AXIS_2: [i32; 30] = [
        30, 34, 38, 42, 46, 110, 114, 118, 122, 126, 190, 194, 198, 202, 206, 270, 274, 278, 282,
        286, 350, 354, 358, 362, 366, 430, 434, 438, 442, 446,
    ];

    /// Expected element values after summing the ramp source over axis 3.
    const EXPECTED_AXIS_3: [i32; 24] = [
        10, 35, 60, 85, 110, 135, 160, 185, 210, 235, 260, 285, 310, 335, 360, 385, 410, 435, 460,
        485, 510, 535, 560, 585,
    ];

    /// Builds the `[2, 3, 4, 5]` source matrix filled with `0.0, 1.0, ..., 119.0`.
    fn ramp_source<D: Device>() -> Matrix<Owned<f32>, DimDyn, D> {
        let values: Vec<f32> = (0..2 * 3 * 4 * 5).map(|i| i as f32).collect();
        Matrix::from_vec(values, [2, 3, 4, 5])
    }

    /// Expected element values after summing over axis 0: the even numbers in `60..=178`.
    fn expected_axis_0() -> Vec<f32> {
        (60..=178)
            .filter(|i| i % 2 == 0)
            .map(|i| i as f32)
            .collect()
    }

    /// Converts integer expectations into the `f32` values stored in the matrices.
    fn as_f32(values: &[i32]) -> Vec<f32> {
        values.iter().map(|&v| v as f32).collect()
    }

    /// Asserts that the absolute sum of `actual - expected` stays below `1e-6`.
    fn assert_close<D: Device>(
        actual: &Matrix<Owned<f32>, DimDyn, D>,
        expected: &Matrix<Owned<f32>, DimDyn, D>,
    ) {
        let diff = actual.to_ref() - expected.to_ref();
        let diff_sum = diff.asum();
        assert!(diff_sum < 1e-6);
    }

    /// Sums a 4-d matrix over each axis with the reduced axis removed.
    fn test_4d<D: Device>() {
        let source = ramp_source::<D>();

        let sum_0 = source.to_ref().sum(0, false);
        let sum_1 = source.to_ref().sum(1, false);
        let sum_2 = source.to_ref().sum(2, false);
        let sum_3 = source.to_ref().sum(3, false);

        assert_eq!(sum_0.shape().slice(), [3, 4, 5]);
        assert_eq!(sum_1.shape().slice(), [2, 4, 5]);
        assert_eq!(sum_2.shape().slice(), [2, 3, 5]);
        assert_eq!(sum_3.shape().slice(), [2, 3, 4]);

        assert_close(&sum_0, &Matrix::from_vec(expected_axis_0(), [3, 4, 5]));
        assert_close(
            &sum_1,
            &Matrix::from_vec(as_f32(&EXPECTED_AXIS_1), [2, 4, 5]),
        );
        assert_close(
            &sum_2,
            &Matrix::from_vec(as_f32(&EXPECTED_AXIS_2), [2, 3, 5]),
        );
        assert_close(
            &sum_3,
            &Matrix::from_vec(as_f32(&EXPECTED_AXIS_3), [2, 3, 4]),
        );
    }
    #[test]
    fn test_4d_cpu() {
        test_4d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_4d_gpu() {
        test_4d::<crate::device::nvidia::Nvidia>();
    }

    /// Sums a 4-d matrix over each axis while keeping the reduced axis with length 1.
    fn test_4d_keep_dim<D: Device>() {
        let source = ramp_source::<D>();

        let sum_0 = source.to_ref().sum(0, true);
        let sum_1 = source.to_ref().sum(1, true);
        let sum_2 = source.to_ref().sum(2, true);
        let sum_3 = source.to_ref().sum(3, true);

        assert_eq!(sum_0.shape().slice(), [1, 3, 4, 5]);
        assert_eq!(sum_1.shape().slice(), [2, 1, 4, 5]);
        assert_eq!(sum_2.shape().slice(), [2, 3, 1, 5]);
        assert_eq!(sum_3.shape().slice(), [2, 3, 4, 1]);

        assert_close(&sum_0, &Matrix::from_vec(expected_axis_0(), [1, 3, 4, 5]));
        assert_close(
            &sum_1,
            &Matrix::from_vec(as_f32(&EXPECTED_AXIS_1), [2, 1, 4, 5]),
        );
        assert_close(
            &sum_2,
            &Matrix::from_vec(as_f32(&EXPECTED_AXIS_2), [2, 3, 1, 5]),
        );
        assert_close(
            &sum_3,
            &Matrix::from_vec(as_f32(&EXPECTED_AXIS_3), [2, 3, 4, 1]),
        );
    }
    #[test]
    fn test_4d_keep_dim_cpu() {
        test_4d_keep_dim::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_4d_keep_dim_gpu() {
        test_4d_keep_dim::<crate::device::nvidia::Nvidia>();
    }
}