1use crate::{
2 device::Device,
3 dim::{DimDyn, DimTrait, LessDimTrait},
4 index::index_dyn_impl::Index,
5 matrix::{Matrix, Owned, Ref},
6 num::Num,
7};
8
9impl<T: Num, D: Device> Matrix<Ref<&T>, DimDyn, D> {
10 #[expect(clippy::missing_panics_doc)]
11 #[must_use]
12 pub fn sum(&self, axis: usize, keep_dim: bool) -> Matrix<Owned<T>, DimDyn, D> {
13 let shape = self.shape();
14 assert!(axis < shape.len(), "Invalid axis");
15
16 let result_shape = self.shape().remove_axis(axis);
17
18 let mut result = Matrix::zeros(result_shape);
19
20 for i in 0..shape[axis] {
21 let mut result_view_mut = result.to_ref_mut();
22 let s = self.clone();
23 let s = s.index_axis_dyn(Index::new(axis, i));
24 result_view_mut.add_assign(&s);
25 }
26
27 if keep_dim {
28 result.add_axis(axis);
29 }
30 result
31 }
32}
33
34#[expect(clippy::missing_panics_doc, clippy::needless_pass_by_value)]
35pub fn sum_to<T: Num, D: Device>(
36 source: Matrix<Ref<&T>, DimDyn, D>,
37 target: Matrix<Ref<&mut T>, DimDyn, D>,
38) {
39 assert!(
40 source.shape().len() >= target.shape().len(),
41 "source.shape().len() <= target.shape().len()"
42 );
43
44 let diff_len = source.shape().len() - target.shape().len();
45 if source.shape().slice() == target.shape().slice() {
46 let target = target;
47 target.copy_from(&source);
48 return;
49 }
50
51 if diff_len == 0 {
52 let mut diff_axis = Vec::new();
53 for (idx, (s, t)) in source
54 .shape()
55 .slice()
56 .iter()
57 .zip(target.shape().slice().iter())
58 .enumerate()
59 {
60 if *s == *t {
61 continue;
62 }
63 if *t == 1 {
64 diff_axis.push(idx);
65 } else {
66 panic!("hoge");
67 }
68 }
69
70 let mut tmp = source.new_matrix();
71 for axis in diff_axis {
72 let tmp_sum = {
73 let tmp_ref = tmp.to_ref();
74 tmp_ref.sum(axis, true)
75 };
76 tmp = tmp_sum;
77 }
78 target.copy_from(&tmp);
79 return;
80 }
81
82 assert!(
83 source.shape().is_include(target.shape()),
84 "!source.shape().is_include(target.shape())"
85 );
86
87 if diff_len == 1 {
88 let target = target;
89 let ans = source.sum(0, false);
90 target.copy_from(&ans);
91 } else {
92 sum_to(source.sum(0, false).to_ref(), target);
93 }
94}
95
#[cfg(test)]
mod sum {
    #![expect(
        clippy::float_cmp,
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation
    )]
    use crate::{
        device::Device,
        dim::{DimDyn, DimTrait},
        matrix::{Matrix, Owned},
    };

    /// Converts an integer fixture table into the `f32` values stored in the matrices.
    fn as_f32(values: &[i32]) -> Vec<f32> {
        values.iter().map(|&v| v as f32).collect()
    }

    /// Elements `0, 1, 2, ...` filling a `[2, 3, 4, 5]` matrix.
    fn source_values() -> Vec<f32> {
        (0..2 * 3 * 4 * 5).map(|v| v as f32).collect()
    }

    /// Expected elements after summing the fixture over axis 0: the even numbers 60..=178.
    fn expected_axis_0() -> Vec<f32> {
        (60..=178).step_by(2).map(|v| v as f32).collect()
    }

    /// Expected elements after summing the fixture over axis 1.
    fn expected_axis_1() -> Vec<f32> {
        as_f32(&[
            60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117,
            240, 243, 246, 249, 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288,
            291, 294, 297,
        ])
    }

    /// Expected elements after summing the fixture over axis 2.
    fn expected_axis_2() -> Vec<f32> {
        as_f32(&[
            30, 34, 38, 42, 46, 110, 114, 118, 122, 126, 190, 194, 198, 202, 206, 270, 274, 278,
            282, 286, 350, 354, 358, 362, 366, 430, 434, 438, 442, 446,
        ])
    }

    /// Expected elements after summing the fixture over axis 3.
    fn expected_axis_3() -> Vec<f32> {
        as_f32(&[
            10, 35, 60, 85, 110, 135, 160, 185, 210, 235, 260, 285, 310, 335, 360, 385, 410, 435,
            460, 485, 510, 535, 560, 585,
        ])
    }

    /// Asserts that the absolute sum of `actual - expected` is below `1e-6`.
    fn assert_close<D: Device>(
        actual: &Matrix<Owned<f32>, DimDyn, D>,
        expected: &Matrix<Owned<f32>, DimDyn, D>,
    ) {
        let difference = actual.to_ref() - expected.to_ref();
        assert!(difference.asum() < 1e-6);
    }

    /// Sums a `[2, 3, 4, 5]` matrix over each axis with the summed axis removed.
    fn test_4d<D: Device>() {
        let source: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(source_values(), [2, 3, 4, 5]);

        let along_0 = source.to_ref().sum(0, false);
        let along_1 = source.to_ref().sum(1, false);
        let along_2 = source.to_ref().sum(2, false);
        let along_3 = source.to_ref().sum(3, false);

        assert_eq!(along_0.shape().slice(), [3, 4, 5]);
        assert_eq!(along_1.shape().slice(), [2, 4, 5]);
        assert_eq!(along_2.shape().slice(), [2, 3, 5]);
        assert_eq!(along_3.shape().slice(), [2, 3, 4]);

        let want_0 = Matrix::<_, DimDyn, D>::from_vec(expected_axis_0(), [3, 4, 5]);
        assert_close(&along_0, &want_0);

        let want_1 = Matrix::<_, DimDyn, D>::from_vec(expected_axis_1(), [2, 4, 5]);
        assert_close(&along_1, &want_1);

        let want_2 = Matrix::<_, DimDyn, D>::from_vec(expected_axis_2(), [2, 3, 5]);
        assert_close(&along_2, &want_2);

        let want_3 = Matrix::<_, DimDyn, D>::from_vec(expected_axis_3(), [2, 3, 4]);
        assert_close(&along_3, &want_3);
    }
    #[test]
    fn test_4d_cpu() {
        test_4d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_4d_gpu() {
        test_4d::<crate::device::nvidia::Nvidia>();
    }

    /// Sums a `[2, 3, 4, 5]` matrix over each axis with the summed axis kept at length 1.
    fn test_4d_keep_dim<D: Device>() {
        let source = Matrix::<_, DimDyn, D>::from_vec(source_values(), [2, 3, 4, 5]);

        let along_0 = source.to_ref().sum(0, true);
        let along_1 = source.to_ref().sum(1, true);
        let along_2 = source.to_ref().sum(2, true);
        let along_3 = source.to_ref().sum(3, true);

        assert_eq!(along_0.shape().slice(), [1, 3, 4, 5]);
        assert_eq!(along_1.shape().slice(), [2, 1, 4, 5]);
        assert_eq!(along_2.shape().slice(), [2, 3, 1, 5]);
        assert_eq!(along_3.shape().slice(), [2, 3, 4, 1]);

        let want_0 = Matrix::<_, DimDyn, D>::from_vec(expected_axis_0(), [1, 3, 4, 5]);
        assert_close(&along_0, &want_0);

        let want_1 = Matrix::<_, DimDyn, D>::from_vec(expected_axis_1(), [2, 1, 4, 5]);
        assert_close(&along_1, &want_1);

        let want_2 = Matrix::<_, DimDyn, D>::from_vec(expected_axis_2(), [2, 3, 1, 5]);
        assert_close(&along_2, &want_2);

        let want_3 = Matrix::<_, DimDyn, D>::from_vec(expected_axis_3(), [2, 3, 4, 1]);
        assert_close(&along_3, &want_3);
    }
    #[test]
    fn test_4d_keep_dim_cpu() {
        test_4d_keep_dim::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn test_4d_keep_dim_gpu() {
        test_4d_keep_dim::<crate::device::nvidia::Nvidia>();
    }
}